在sklearn数字数据集上输入形状错误

时间:2018-05-14 03:39:23

标签: python machine-learning scikit-learn

import sklearn
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn import svm

digits = datasets.load_digits()

clf = svm.SVC(gamma=0.01,C= 100)

x = digits.data[:-10]
y = digits.data[:-10]

clf.fit(x,y)

print ("prediction:",clf.predict(digits.data[-1]))

plt.imshow(digits.image[-1],cmap = plt.cm.gray_r, interpolation ="nearest")
plt.show

我收到错误

Traceback (most recent call last):
  File "python", line 14, in <module>
ValueError: bad input shape (1787, 64)

我不确定阵列的形状应该是什么以及如何?!!

有人可以帮到这里!感谢

2 个答案:

答案 0 :(得分:0)

您应该将digits.target用作y,而不是digits.data。 SVM仅接受一维预测变量,该变量暗示您将错误的值传递给它。 targetx中8x8图像的值(0..9)。

答案 1 :(得分:-1)

Arya McCarthy关于使用digits.target作为y的建议是正确的。因此,将digits.data声明为特征变量xdigits.target作为目标变量y是正确的。请试试我修改的代码如下:

import sklearn
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn import svm

digits = datasets.load_digits()

clf = svm.SVC(gamma=0.01,C= 100)

x = digits.data
y = digits.target

clf.fit(x,y)
print ("prediction:",clf.predict(digits.data[-1].reshape(1, -1)))
plt.imshow(digits.images[-1],cmap = plt.cm.gray_r, interpolation ="nearest")
plt.show()

以下是我运行代码的结果

enter image description here