scikit multilabel分类:ValueError:输入形状不好

时间:2013-12-02 19:03:42

标签: machine-learning classification scikit-learn stochastic-process

我认为SGDClassifier() loss='log'支持多标签分类,我不必使用OneVsRestClassifier。 Check this

现在,我的数据集很大,我正在使用HashingVectorizer并将结果作为输入传递给SGDClassifier。我的目标有42048个功能。

当我运行时,如下:

clf.partial_fit(X_train_batch, y)

我得到:ValueError: bad input shape (300000, 42048)

我还使用了类作为参数,但仍然存在同样的问题。

clf.partial_fit(X_train_batch, y, classes=np.arange(42048))

在SGDClassifier的文档中,它说y : numpy array of shape [n_samples]

1 个答案:

答案 0 :(得分:4)

不,SGDClassifier 进行多标记分类 - 它会进行多类分类,这是一个不同的问题,尽管两者都是使用一个解决的 - vs-all减少问题。

然后,SGD和OneVsRestClassifier.fit都不接受y的稀疏矩阵。前者需要一系列标签,正如您已经发现的那样。为了多标记的目的,后者需要一个标签列表列表,例如

y = [[1], [2, 3], [1, 3]]

表示X[0]标签为1,X[1]标签为{2,3}X[2]标签为{1,3}