Predicting multilabel data with sklearn

Time: 2016-05-06 12:48:06

Tags: python scikit-learn

According to the documentation, the OneVsRest classifier supports multilabel classification: http://scikit-learn.org/stable/modules/multiclass.html#multilabel-learning

Here is the code I am trying to run:

from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.cross_validation import train_test_split
from sklearn.svm import SVC

x = [[1,2,3],[3,3,2],[8,8,7],[3,7,1],[4,5,6]]
y = [['bar','foo'],['bar'],['foo'],['foo','jump'],['bar','fox','jump']]

y_enc = MultiLabelBinarizer().fit_transform(y)

train_x, train_y, test_x, test_y = train_test_split(x, y_enc, test_size=0.33)

clf = OneVsRestClassifier(SVC())
clf.fit(train_x, train_y)
predictions = clf.predict_proba(test_x)

my_metrics = metrics.classification_report( test_y, predictions)
print my_metrics

I get the following error:

Traceback (most recent call last):
  File "multilabel.py", line 178, in <module>
    clf.fit(train_x, train_y)
  File "/sklearn/lib/python2.6/site-packages/sklearn/multiclass.py", line 277, in fit
    Y = self.label_binarizer_.fit_transform(y)
  File "/sklearn/lib/python2.6/site-packages/sklearn/base.py", line 455, in fit_transform
    return self.fit(X, **fit_params).transform(X)
  File "/sklearn/lib/python2.6/site-packages/sklearn/preprocessing/label.py", line 302, in fit
    raise ValueError("Multioutput target data is not supported with "
ValueError: Multioutput target data is not supported with label binarization

Leaving out MultiLabelBinarizer produces the same error, so I assume it is not the problem. Does anyone know how to use this classifier with multilabel data?

3 answers:

Answer 0: (score: 8)

The output of train_test_split() is unpacked in the wrong order. Change this line:

train_x, train_y, test_x, test_y = train_test_split(x, y_enc, test_size=0.33)

to this:

train_x, test_x, train_y, test_y = train_test_split(x, y_enc, test_size=0.33)
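To make the return order concrete, here is a minimal sketch. It assumes scikit-learn 0.18 or later, where train_test_split is imported from sklearn.model_selection (the post itself uses the older sklearn.cross_validation module), and the target values below are invented for illustration. Each input is split into a (train, test) pair before the next input is handled, so with the original unpacking the name train_y actually received the held-out feature rows.

```python
from sklearn.model_selection import train_test_split

x = [[1, 2, 3], [3, 3, 2], [8, 8, 7], [3, 7, 1], [4, 5, 6]]
# Illustrative two-label indicator targets (not the data from the question)
y = [[1, 0], [0, 1], [1, 1], [0, 0], [1, 0]]

# Return order is: x_train, x_test, y_train, y_test
train_x, test_x, train_y, test_y = train_test_split(
    x, y, test_size=0.33, random_state=0
)

# Features and targets of the same split must have matching lengths
print(len(train_x), len(train_y))
print(len(test_x), len(test_y))
```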

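Also, clf.predict_proba only works if you change SVC() to SVC(probability=True). And because classification_report expects class predictions rather than probabilities, change clf.predict_proba to clf.predict.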
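The difference between the two calls can be sketched as follows. This is an illustration under stated assumptions, not part of the original answer: the data comes from make_multilabel_classification rather than from the question, and the 0.5 cutoff is an arbitrary choice for turning probabilities into labels.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Synthetic multilabel data (assumption: stands in for real features and labels)
X, Y = make_multilabel_classification(
    n_samples=60, n_features=5, n_classes=3, random_state=0
)

# probability=True is required for predict_proba to be available on SVC
clf = OneVsRestClassifier(SVC(probability=True))
clf.fit(X, Y)

labels = clf.predict(X[:5])       # binary indicator matrix, one column per label
probs = clf.predict_proba(X[:5])  # per-label probabilities in [0, 1]

# If labels are needed from probabilities, a cutoff has to be chosen explicitly
thresholded = (probs >= 0.5).astype(int)

print(labels.shape, probs.shape)
```

Only the indicator form (labels or thresholded) is suitable input for metrics.classification_report.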

Putting it all together:

from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
# sklearn.cross_validation was removed in scikit-learn 0.20; use sklearn.model_selection
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC


x = [[1,2,3],[3,3,2],[8,8,7],[3,7,1],[4,5,6]]
y = [['bar','foo'],['bar'],['foo'],['foo','jump'],['bar','fox','jump']]

mlb = MultiLabelBinarizer()
y_enc = mlb.fit_transform(y)

train_x, test_x, train_y, test_y = train_test_split(x, y_enc, test_size=0.33)

clf = OneVsRestClassifier(SVC(probability=True))
clf.fit(train_x, train_y)
predictions = clf.predict(test_x)

my_metrics = metrics.classification_report( test_y, predictions)
print(my_metrics)

This runs without errors for me.
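The predictions returned by clf.predict are rows of a binary indicator matrix. To read them as label names again, the fitted MultiLabelBinarizer can decode them with inverse_transform. The sketch below applies the encoder to the labels from the question only; in practice, the same call would be made on the predictions array.

```python
from sklearn.preprocessing import MultiLabelBinarizer

y = [['bar', 'foo'], ['bar'], ['foo'], ['foo', 'jump'], ['bar', 'fox', 'jump']]

mlb = MultiLabelBinarizer()
y_enc = mlb.fit_transform(y)

# classes_ gives the column order of the indicator matrix
print(mlb.classes_)
print(y_enc)

# inverse_transform maps each indicator row back to a tuple of label names
decoded = mlb.inverse_transform(y_enc)
print(decoded)
```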

Answer 1: (score: 3)

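I also ran into "ValueError: Multioutput target data is not supported with label binarization" when using OneVsRestClassifier. In my case the problem was that the training data was of type "list"; after converting it with np.array(), it worked.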

Answer 2: (score: 0)

For me, wrapping train_x, train_y, test_x and test_y in np.array() solved the problem.
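A minimal sketch of that conversion, assuming the features and an already binarized label matrix are held as nested Python lists (the values are invented for illustration). Whether the conversion is strictly required may depend on the scikit-learn version in use.

```python
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Plain Python lists: features and a two-label indicator matrix (illustrative values)
train_x = [[1, 2, 3], [3, 3, 2], [8, 8, 7], [3, 7, 1], [4, 5, 6], [7, 8, 8]]
train_y = [[1, 1], [1, 0], [0, 1], [0, 1], [1, 0], [0, 1]]

# Convert to NumPy arrays before fitting
train_x = np.array(train_x)
train_y = np.array(train_y)

clf = OneVsRestClassifier(SVC())
clf.fit(train_x, train_y)

predictions = clf.predict(train_x)
print(predictions.shape)
```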