scikit-learn/scikit-learn

Error in using multi-label classification in partial_fit() in OvR

Open

#8381 created on February 17, 2017

Labels: Bug, help wanted, module:multiclass


Description

When calling partial_fit() on a OneVsRestClassifier with multilabel targets, a ValueError is raised. Calling fit() on the same data raises no error and works as expected.

Steps/Code to Reproduce

```python
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
import numpy as np

categories = ['a', 'b', 'c']
X = ["This is a test", "This is another attempt", "And this is a test too!"]
Y = [['a', 'b'], ['b', 'c'], ['a', 'b']]

mlb = MultiLabelBinarizer(classes=categories)
# non_negative=True is the 0.18.x spelling; it was removed in later releases,
# where alternate_sign=False is the closest replacement
vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 18, non_negative=True)
clf = OneVsRestClassifier(MultinomialNB(alpha=0.01))

X_train = vectorizer.fit_transform(X)
Y_train = mlb.fit_transform(Y)

# Case 1: pass the raw label names as classes
clf.partial_fit(X_train, Y_train, categories)

# Case 2 (run separately): pass the binarized labels as classes
clf.partial_fit(X_train, Y_train, mlb.transform(Y))
```

Description of code

  • Case 1: passing the raw label names as classes, without transforming them: partial_fit(X_train, Y_train, classes=categories)

      ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    
  • Case 2: passing classes after transforming them with the same MultiLabelBinarizer: partial_fit(X_train, Y_train, classes=mlb.transform(categories))

       ValueError: The object was not fitted with multilabel input.
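To see what partial_fit() receives in each case, the public helpers `type_of_target` and `unique_labels` from `sklearn.utils.multiclass` can be applied to the inputs above. This is a minimal sketch assuming a recent scikit-learn release, and assuming (as the private helper `_check_partial_fit_first_call` does in recent releases, not verified for 0.18.1) that partial_fit() stores its classes through `unique_labels(classes)`. It only inspects the inputs and never calls partial_fit().

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.utils.multiclass import type_of_target, unique_labels

categories = ['a', 'b', 'c']
Y = [['a', 'b'], ['b', 'c'], ['a', 'b']]

mlb = MultiLabelBinarizer(classes=categories)
Y_train = mlb.fit_transform(Y)  # 3 x 3 matrix of 0/1 values

# The target passed to partial_fit is a multilabel indicator matrix
target_type = type_of_target(Y_train)

# Case 1: the classes are the label names themselves
case1_classes = unique_labels(categories)

# Case 2: an indicator matrix collapses to its column indices
case2_classes = unique_labels(mlb.transform(Y))

# The values that actually occur in Y_train
present = np.unique(Y_train)
```

In Case 1 none of the 0/1 values in Y_train are among 'a', 'b', 'c'; in Case 2 the stored classes are the plain column indices 0, 1, 2 rather than an indicator matrix.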
    

Expected Results

No error is raised, just as with fit().
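One possible workaround for incremental training, sketched here under stated assumptions, swaps in a different estimator: `sklearn.multioutput.MultiOutputClassifier`, whose partial_fit() takes one `classes` array per output column. For an indicator target this trains one independent binary classifier per label, which is what OneVsRestClassifier.fit() does for multilabel input. The sketch assumes a recent scikit-learn release (it uses `alternate_sign=False` in place of the removed `non_negative=True`); `per_label_classes` is an illustrative name. It also confirms that OneVsRestClassifier.fit() accepts the same data.

```python
import numpy as np
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer

categories = ['a', 'b', 'c']
X = ["This is a test", "This is another attempt", "And this is a test too!"]
Y = [['a', 'b'], ['b', 'c'], ['a', 'b']]

mlb = MultiLabelBinarizer(classes=categories)
# alternate_sign=False keeps hashed features non-negative, as MultinomialNB requires
vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 18,
                               alternate_sign=False)

X_train = vectorizer.transform(X)
Y_train = mlb.fit_transform(Y)

# fit() on the full data works, as the report states
ovr = OneVsRestClassifier(MultinomialNB(alpha=0.01)).fit(X_train, Y_train)

# Incremental alternative: one binary (0/1) problem per label column
inc = MultiOutputClassifier(MultinomialNB(alpha=0.01))
per_label_classes = [np.array([0, 1]) for _ in range(Y_train.shape[1])]
inc.partial_fit(X_train, Y_train, classes=per_label_classes)  # first mini-batch
inc.partial_fit(X_train, Y_train)  # later mini-batches may omit classes

# Map the 0/1 predictions back to label names
predicted = mlb.inverse_transform(inc.predict(X_train))
```

Here the same batch is fed twice only to illustrate the call pattern; in practice each partial_fit() call would receive a new mini-batch.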

Actual Results

  • Case 1

```
Traceback (most recent call last):
  File "/path_to_module/Check.py", line 18, in <module>
    clf.partial_fit(X_train, Y_train, categories)
  File "/library/python2.7/dist-packages/sklearn/multiclass.py", line 260, in partial_fit
    if np.setdiff1d(y, self.classes_):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```

  • Case 2

```
Traceback (most recent call last):
  File "/path_to_module/Check.py", line 18, in <module>
    clf.partial_fit(X_train, Y_train, mlb.transform(Y))
  File "/library/python2.7/dist-packages/sklearn/multiclass.py", line 265, in partial_fit
    Y = self.label_binarizer_.transform(y)
  File "/library/python2.7/dist-packages/sklearn/preprocessing/label.py", line 329, in transform
    raise ValueError("The object was not fitted with multilabel"
ValueError: The object was not fitted with multilabel input.
```
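The Case 1 traceback ends at `if np.setdiff1d(y, self.classes_):`. `np.setdiff1d` returns an array, and NumPy refuses to use an array with more than one element as an `if` condition. The NumPy-only sketch below reproduces that failure; the class values 5, 6, 7 are hypothetical stand-ins chosen so that, as in Case 1, none of them equals a 0/1 entry of the indicator matrix. It also shows a size check that asks "are there unknown labels?" without ambiguity.

```python
import numpy as np

# 0/1 indicator matrix shaped like Y_train in the report
y = np.array([[1, 1, 0],
              [0, 1, 1],
              [1, 1, 0]])
# Hypothetical stand-in for self.classes_: labels that never equal 0 or 1
classes_ = np.array([5, 6, 7])

# setdiff1d flattens y and returns the sorted unique values not in classes_
leftover = np.setdiff1d(y, classes_)  # array([0, 1])


def truth_value_is_ambiguous(arr):
    """Return True if using arr as an if-condition raises ValueError."""
    try:
        if arr:
            pass
    except ValueError:
        return True
    return False


ambiguous = truth_value_is_ambiguous(leftover)

# Checking the size expresses the intent without ambiguity
has_unknown_labels = leftover.size > 0
```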

Observation

Versions

Linux-3.16.0-77-generic-x86_64-with-Ubuntu-14.04-trusty
('Python', '2.7.6 (default, Oct 26 2016, 20:30:19) \n[GCC 4.8.4]')
('NumPy', '1.12.0')
('SciPy', '0.18.1')
('Scikit-Learn', '0.18.1')
