scikit-learn/scikit-learn

Bug: BaggingClassifier for multiclass usage

Open

#8409 opened on February 20, 2017

Labels: Bug, Needs Reproducible Code, help wanted, module:ensemble

Description

Hi,

The BaggingClassifier does not check whether the number of classes in the random sample drawn for each of its estimators matches the number of classes in the full dataset. This results in the following error when its predict method is used:

---------------------------------------------------------------------------
Sub-process traceback:
---------------------------------------------------------------------------
ValueError                                         Thu Feb 16 11:17:13 2017
PID: 25287                                    Python 2.7.5: /usr/bin/python
...........................................................................
.local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py in __call__(self=<sklearn.externals.joblib.parallel.BatchedCalls object>)
    126     def __init__(self, iterator_slice):
    127         self.items = list(iterator_slice)
    128         self._size = len(self.items)
    129
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
        func = <function _parallel_predict_proba>
        args = ([SVC(C=100, cache_size=200, class_weight=None, co...5695085, shrinking=True,
  tol=0.001, verbose=10), SVC(C=100, cache_size=200, class_weight=None, co...8919740, shrinking=True,
  tol=0.001, verbose=10)], [array([   0,    1,    2, ..., 2045, 2046, 2047]), array([   0,    1,    2, ..., 2045, 2046, 2047])], memmap([[  1.21488877e-01,  -6.84937861e-01,  -5...022708e-05,   1.09372859e-06,  -8.55540498e-06]]), 1000)
        kwargs = {}
        self.items = [(<function _parallel_predict_proba>, ([SVC(C=100, cache_size=200, class_weight=None, co...5695085, shrinking=True,
  tol=0.001, verbose=10), SVC(C=100, cache_size=200, class_weight=None, co...8919740, shrinking=True,
  tol=0.001, verbose=10)], [array([   0,    1,    2, ..., 2045, 2046, 2047]), array([   0,    1,    2, ..., 2045, 2046, 2047])], memmap([[  1.21488877e-01,  -6.84937861e-01,  -5...022708e-05,   1.09372859e-06,  -8.55540498e-06]]), 1000), {})]
    132
    133     def __len__(self):
    134         return self._size
    135

...........................................................................
.local/lib/python2.7/site-packages/sklearn/ensemble/bagging.py in _parallel_predict_proba(estimators=[SVC(C=100, cache_size=200, class_weight=None, co...5695085, shrinking=True,
  tol=0.001, verbose=10), SVC(C=100, cache_size=200, class_weight=None, co...8919740, shrinking=True,
  tol=0.001, verbose=10)], estimators_features=[array([   0,    1,    2, ..., 2045, 2046, 2047]), array([   0,    1,    2, ..., 2045, 2046, 2047])], X=memmap([[  1.21488877e-01,  -6.84937861e-01,  -5...022708e-05,   1.09372859e-06,  -8.55540498e-06]]), n_classes=1000)
    130     for estimator, features in zip(estimators, estimators_features):
    131         if hasattr(estimator, "predict_proba"):
    132             proba_estimator = estimator.predict_proba(X[:, features])
    133
    134             if n_classes == len(estimator.classes_):
--> 135                 proba += proba_estimator
        proba = array([[ 0.00130233,  0.00013968,  0.00144125, ....  0.00016293,
         0.00010567,  0.00053245]])
        proba_estimator = array([[  1.02577963e-03,   3.75469340e-04,   9....362413e-05,   1.45631109e-04,   3.04322015e-04]])
    136
    137             else:
    138                 proba[:, estimator.classes_] += \
    139                     proba_estimator[:, range(len(estimator.classes_))]

ValueError: operands could not be broadcast together with shapes (8009,1000) (8009,999) (8009,1000)
___________________________________________________________________________
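Reading the frames above, the failing statement is line 135, inside the `if n_classes == len(estimator.classes_)` branch. So this sub-estimator reported 1000 classes, while its `predict_proba` returned only 999 columns (shapes `(8009,1000)` vs `(8009,999)`). The NumPy sketch below reproduces that kind of broadcasting failure and shows the column-scattering idea the `else` branch uses (`proba[:, estimator.classes_] += ...`), with an explicit width check added. `accumulate_proba` is a hypothetical helper for illustration, not scikit-learn's code, and it assumes class labels are already integer-encoded as `0..n_classes-1`.

```python
import numpy as np


def accumulate_proba(proba, proba_estimator, estimator_classes):
    """Add one sub-estimator's class probabilities into the ensemble total.

    Hypothetical helper: `estimator_classes` must hold integer-encoded labels
    (0..n_classes-1), one per column of `proba_estimator`, in column order.
    """
    n_cols = proba_estimator.shape[1]
    if n_cols != len(estimator_classes):
        # The situation in the traceback: column count disagrees with classes_.
        raise ValueError(
            f"predict_proba returned {n_cols} columns but the estimator "
            f"reports {len(estimator_classes)} classes"
        )
    # Scatter each column into the slot of the class it belongs to.
    proba[:, estimator_classes] += proba_estimator
    return proba


n_samples, n_classes = 4, 5
proba = np.zeros((n_samples, n_classes))

# Plain in-place addition fails when the widths differ, as in the report.
try:
    proba += np.full((n_samples, n_classes - 1), 0.25)
except ValueError as exc:
    print("broadcast error:", exc)

# A sub-estimator that never saw class 2 during fitting.
seen = np.array([0, 1, 3, 4])
accumulate_proba(proba, np.full((n_samples, len(seen)), 0.25), seen)
print(proba[0])  # column 2 is still zero; the other columns received 0.25
```

Note that the scatter only helps when `classes_` and the `predict_proba` width agree; in the reported case they did not, which the width check turns into an explicit error instead of a broadcasting failure.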

It would be nice to get a warning message if the number of classes used to train an estimator in the BaggingClassifier does not match the overall number of classes in the training set.
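A minimal sketch of such a check, under stated assumptions. With bootstrap sampling, a class holding k of the n training samples is absent from one draw with probability (1 - k/n)^n, roughly e^(-k), so a single-sample class is missed about 37% of the time; with 1000 classes, some sub-estimators are very likely to miss at least one. The helper `warn_on_missing_classes` below is hypothetical (not an existing scikit-learn function). It simulates bootstrap draws with NumPy instead of fitting a real ensemble.

```python
import warnings

import numpy as np


def warn_on_missing_classes(n_classes, estimator_classes):
    """Warn for every sub-estimator fitted on fewer than `n_classes` classes.

    Hypothetical helper: `estimator_classes` is a list with one array of
    (integer-encoded) class labels per sub-estimator.
    Returns the indices of the affected sub-estimators.
    """
    incomplete = []
    for i, classes in enumerate(estimator_classes):
        if len(classes) != n_classes:
            incomplete.append(i)
            warnings.warn(
                f"estimator {i} was trained on {len(classes)} of "
                f"{n_classes} classes",
                UserWarning,
            )
    return incomplete


# Simulate bootstrap draws over a label vector with one rare class.
rng = np.random.RandomState(0)
y = np.array([0] * 50 + [1] * 50 + [2])  # class 2 has a single sample
n_classes = len(np.unique(y))
estimator_classes = []
for _ in range(10):
    # Draw len(y) indices with replacement, in the spirit of bootstrap=True.
    idx = rng.randint(0, len(y), len(y))
    estimator_classes.append(np.unique(y[idx]))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    affected = warn_on_missing_classes(n_classes, estimator_classes)
print(affected, len(caught))
```

For a fitted `BaggingClassifier`, the inputs would presumably be `clf.n_classes_` and `[est.classes_ for est in clf.estimators_]`; running the check once at the end of `fit` would surface the problem before `predict` is ever called.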
