我正在使用Scikit-Learn和Tensorflow进行O'Reilly的动手机器学习。
我正在训练MNIST数据集上的分类器,但出现错误
ValueError: The number of classes has to be greater than one; got 1 class
这是我的代码
mnist = fetch_openml('mnist_784', version=1, cache=True)
X, y = mnist["data"], mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
y_train_5 = (y_train == 9)
y_test_5 = (y_test == 9)
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
我已经三遍检查了我的代码,但仍然不确定发生了什么。
答案 0 :(得分:3)
sklearn
中来自MNIST数据集的标签包含字符串,而不是整数。因此,设置
y_train_5 = (y_train == '9')
y_test_5 = (y_test == '9')
当您使用整数进行检查时,所有内容都为False
,Python警告您只有一个类。
答案 1 :(得分:0)
此过程完全正确,只需将数字放入字符串中,因为scikit中的标签需要字符串。
var debug: some View {
MyViewWithError(property: self.property)
}