I am working on a multi-label problem using the following model.
from sklearn.linear_model import SGDClassifier

# Linear SVM (hinge loss) trained with SGD
clf = SGDClassifier(loss='hinge', penalty='l2',
                    alpha=1e-3, random_state=42,
                    max_iter=5, tol=None)
Here is the code snippet.
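from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
vect = TfidfVectorizer()
# X is a sparse TF-IDF matrix; Y is a binary indicator matrix with one column per label
X, Y = vect.fit_transform(X_raw), mlb.fit_transform(Y_raw)
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, random_state=0, test_size=0.33, shuffle=True)
clf.fit(X_Train, Y_Train)  # this is the call that raises the error shown below
predict = clf.predict(X_Test)
predict_label = mlb.inverse_transform(predict)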
Note that X.shape is (829, 565) and Y.shape is (829, 251).
I now get the following error:
File "<ipython-input-101-5e447893b03c>", line 1, in <module>
runfile('D:/Testing/model_build_mlb.py', wdir='D:/Testing')
File "C:\Users\ce\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 668, in runfile
execfile(filename, namespace)
File "C:\Users\ce\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 108, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "D:/Testing/model_build_mlb.py", line 1715, in <module>
clf.fit(X_Train,Y_Train)
File "C:\Users\ce\Anaconda3\lib\site-packages\sklearn\linear_model\stochastic_gradient.py", line 586, in fit
sample_weight=sample_weight)
File "C:\Users\ce\Anaconda3\lib\site-packages\sklearn\linear_model\stochastic_gradient.py", line 418, in _fit
X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C")
File "C:\Users\ce\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 578, in check_X_y
y = column_or_1d(y, warn=True)
File "C:\Users\ce\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 614, in column_or_1d
raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (555, 251)
Is this caused by the model, or by something else? Please help. Thanks in advance.
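For reference, the traceback ends in column_or_1d, which suggests that SGDClassifier.fit expects a 1-D target while Y_Train here is a 2-D indicator matrix of shape (555, 251). A minimal sketch of one possible workaround is to wrap the classifier in OneVsRestClassifier, which fits one binary classifier per label column. The toy X_raw and Y_raw below are hypothetical stand-ins for the real data, and test_size=0.25 is chosen only to suit the tiny sample; this is an illustration under those assumptions, not a confirmed fix for the original script.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy data standing in for X_raw / Y_raw from the question.
X_raw = [
    "cheap flights to paris", "paris hotel deals", "python pandas tutorial",
    "learn python fast", "flights and python conference", "hotel near conference",
    "pandas dataframe merge", "cheap hotel paris flights",
]
Y_raw = [
    ["travel"], ["travel"], ["code"], ["code"],
    ["travel", "code"], ["travel"], ["code"], ["travel"],
]

mlb = MultiLabelBinarizer()
vect = TfidfVectorizer()
X, Y = vect.fit_transform(X_raw), mlb.fit_transform(Y_raw)  # Y is 2-D: (n_samples, n_labels)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, random_state=0, test_size=0.25, shuffle=True)

base = SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3,
                     random_state=42, max_iter=5, tol=None)

# A bare SGDClassifier rejects the 2-D indicator matrix, as in the traceback.
raised_on_2d_y = False
try:
    base.fit(X_train, Y_train)
except ValueError:
    raised_on_2d_y = True

# OneVsRestClassifier trains one copy of the base estimator per label column.
clf = OneVsRestClassifier(base)
clf.fit(X_train, Y_train)
predict = clf.predict(X_test)                   # indicator matrix, same width as Y
predict_label = mlb.inverse_transform(predict)  # list of label tuples, one per test row
print(predict_label)
```

With 251 labels this trains 251 independent binary classifiers, so labels that occur in very few of the 829 samples may be predicted poorly or not at all.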