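I trained a model on the mnist_784 data, using a binary label that says whether each digit is a 9. The variable some_digit is one of the training samples, and its target is 9. When I use the model to predict some_digit, however, the prediction is wrong: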
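# imports needed by the code below
import numpy as np
from sklearn.base import clone
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold

# import the data
# (as_frame=False keeps X and y as NumPy arrays so that X[36000] selects a row;
# recent scikit-learn versions return a pandas DataFrame by default)
mnist = fetch_openml('mnist_784', as_frame=False)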
X, y = mnist['data'], mnist['target']
some_digit = X[36000]
print('some_digit is ', y[36000]) # some_digit is 9
# split the data and make it random
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
# true for y_train='9', false for all others
y_train_9 = (y_train == '9')
y_test_9 = (y_test == '9')
# train the data with a label whether it is '9'
sgd_clf = SGDClassifier(random_state=0)
sgd_clf.fit(X_train, y_train_9)
print('I wonder whether [some_digit] is 9 with the model I trained ',
      sgd_clf.predict([some_digit]))
# I wonder whether [some_digit] is 9 with the model I trained [False]
# the prediction is not right
# however, the accuracy is very high: I got 0.9472, 0.9378, 0.9507
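# random_state only has an effect when shuffle=True
# (recent scikit-learn versions raise an error otherwise)
skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)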
for train_index, test_index in skfolds.split(X_train, y_train_9):
    # fit a fresh copy of the classifier on each training fold
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_9[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_9[test_index]
    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    # fraction of correct predictions on the held-out fold
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))
I expected the prediction for [some_digit] to be True, because some_digit is a 9. Why is it not?
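
One thing that may help reconcile the two observations: only about one digit in ten is a 9, so a high accuracy on the "is it a 9?" label does not by itself show that the model recognizes 9s. The sketch below does not use MNIST or the trained model. It uses made-up labels, and the names `y_true`, `y_never_9`, `accuracy` and `recall` are hypothetical helpers introduced only for illustration. It shows that a classifier which always answers False still reaches 90% accuracy on such labels while finding none of the 9s.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def recall(y_true, y_pred):
    """Fraction of the actual positives that were predicted positive."""
    true_positives = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    actual_positives = sum(1 for t in y_true if t)
    return true_positives / actual_positives if actual_positives else 0.0


# Assumption: synthetic digits 0..9 repeated, so exactly 1 in 10 is a 9.
# This only mimics the class balance of y_train_9; it is not MNIST data.
digits = [i % 10 for i in range(10000)]
y_true = [d == 9 for d in digits]

# A "model" that never predicts 9, regardless of the input.
y_never_9 = [False] * len(y_true)

baseline_accuracy = accuracy(y_true, y_never_9)
baseline_recall = recall(y_true, y_never_9)

print('accuracy of always-False baseline:', baseline_accuracy)
print('recall of always-False baseline:', baseline_recall)
```

If the real classifier behaves in a similar way, accuracies around 0.94 would still leave room for many 9s, possibly including some_digit, to be predicted as False. Computing recall (or a confusion matrix) on y_train_9 would be one way to check this.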