scikit-learn: multilabel cross-validation no longer works in v0.19

Time: 2017-10-12 06:17:14

Tags: python scikit-learn

Dear Stack Overflow users,

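I use sklearn to train a multilabel SVM (with probability estimates) for text mining. Each entry has a list of target labels rather than a single label. These targets are transformed with MultiLabelBinarizer: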

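from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import SVC

vect = TfidfVectorizer()
x_train = vect.fit_transform(training_texts)

# Encode each list of labels as one row of a binary indicator matrix
mlb = MultiLabelBinarizer()
training_targets_mlb = mlb.fit_transform(training_targets)

# One linear SVM per label, with probability estimates enabled
clf = OneVsRestClassifier(SVC(kernel='linear', probability=True))
clf.fit(x_train, training_targets_mlb)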

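After upgrading sklearn to the latest version, 0.19, the code above still works, but the following cross-validation code (which, as far as I remember, worked in 0.18) now raises an error: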

from sklearn.model_selection import KFold, cross_val_predict

# Out-of-fold probability estimates for every sample
cv_scores = cross_val_predict(estimator=clf,
                              X=vect.fit_transform(texts),
                              y=mlb.fit_transform(targets),
                              cv=KFold(n_splits=5, shuffle=True),
                              method='predict_proba')

Error:

  File "/usr/local/lib/python2.7/dist-packages/sklearn/model_selection/_validation.py", line 647, in cross_val_predict
    y = le.fit_transform(y)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/preprocessing/label.py", line 111, in fit_transform
    y = column_or_1d(y, warn=True)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.py", line 583, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (18165, 25)
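
Reading the traceback above, the failure seems to come from a label-encoding step (`le.fit_transform(y)`) that expects a 1-D target. The snippet below is a minimal sketch of that reading, not a confirmed diagnosis: it assumes `le` is a `LabelEncoder`, and the tiny `y_indicator` matrix is made up to stand in for the output of MultiLabelBinarizer.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Hypothetical 3-sample, 3-label indicator matrix, like the output
# of MultiLabelBinarizer (values invented for illustration).
y_indicator = np.array([[1, 0, 1],
                        [0, 1, 0],
                        [1, 1, 0]])

try:
    # LabelEncoder only accepts 1-D targets, so a 2-D matrix is rejected.
    LabelEncoder().fit_transform(y_indicator)
    raised = False
except ValueError as exc:
    raised = True
    print("ValueError:", exc)
```

If this reading is right, the error is triggered by the shape of `y` alone, independent of the classifier.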

Is this the expected behaviour, i.e. am I not supposed to pass a 2-D array here? If so, how should I do the cross-validation?
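
One fallback I can imagine is to loop over the KFold splits by hand and call predict_proba on each held-out fold. The code below is only a sketch under assumptions: the helper name `manual_cross_val_predict_proba` is hypothetical, and the synthetic data from `make_multilabel_classification` stands in for my TF-IDF features and binarized targets.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC


def manual_cross_val_predict_proba(estimator, X, y, cv):
    """Hypothetical helper: out-of-fold probabilities for a 2-D indicator y."""
    proba = np.zeros(y.shape, dtype=float)
    for train_idx, test_idx in cv.split(X):
        # Fit a fresh copy on the training fold only
        fold_clf = clone(estimator)
        fold_clf.fit(X[train_idx], y[train_idx])
        # Store per-label probabilities for the held-out samples
        proba[test_idx] = fold_clf.predict_proba(X[test_idx])
    return proba


# Synthetic stand-in for the real features and targets (assumption)
X, y = make_multilabel_classification(n_samples=120, n_features=20,
                                      n_classes=4, random_state=0)

clf = OneVsRestClassifier(SVC(kernel='linear', probability=True))
cv = KFold(n_splits=5, shuffle=True, random_state=0)

cv_proba = manual_cross_val_predict_proba(clf, X, y, cv)
print(cv_proba.shape)
```

Each sample is predicted exactly once, by a model that never saw it during fitting, which is what I was hoping to get from cross_val_predict.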

Thanks for your help!

0 answers:

No answers yet