I'm trying to run the MNIST dataset with TensorFlow. Here is my code:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
X_train = np.array(mnist.train.images, 'float')
X_test = np.array(mnist.test.images, 'float')
y_train = np.array(mnist.train.images, 'int32')
y_test = np.array(mnist.test.images, 'int32')
# Specify feature
feature_columns = [tf.contrib.layers.real_valued_column('', dimension=784)]
# Build a 4-layer DNN with 200, 100, 60, 30 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[200, 100, 60, 30],
                                            n_classes=10,
                                            model_dir="./output")
# Fit model.
classifier.fit(X_train, y_train, batch_size=100, steps=1000)
# Evaluate accuracy.
accuracy_score = classifier.evaluate(X_test, y_test)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
However, I keep getting this error:
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 784 for 'dnn/multi_class_head/softmax_cross_entropy_loss/Squeeze' (op: 'Squeeze') with input shapes: [?,784].
The traceback points me to line 31, which is where I call fit() on the classifier, but I can't figure out why.
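In case it helps, a minimal diagnostic one could run just before the fit() call (assuming the arrays defined above) is to print the shapes being fed in, since the error mentions an input shape of [?, 784]:
# Hypothetical diagnostic snippet: inspect the arrays passed to fit().
# DNNClassifier expects the targets to be a 1-D array of integer class ids,
# so a second dimension of 784 on the labels would explain the squeeze error.
print('X_train shape:', X_train.shape)
print('y_train shape:', y_train.shape)
print('X_test shape:', X_test.shape)
print('y_test shape:', y_test.shape)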
Answer 0 (score: 0)
The following should work fine. TensorFlow version = '1.1.0', with Python 3.6. There may be some issue with the dimensions of your input data, but you can work backwards from this.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
mnist_data = "save/here/mnist"
MNIST_DATASET = input_data.read_data_sets(mnist_data)
train_data = np.array(MNIST_DATASET.train.images, 'float32')
train_target = np.array(MNIST_DATASET.train.labels, 'int64')
test_data = np.array(MNIST_DATASET.test.images, 'float32')
test_target = np.array(MNIST_DATASET.test.labels, 'int64')
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=784)]
classifier = tf.contrib.learn.DNNClassifier(
    feature_columns=feature_columns,
    n_classes=10,
    hidden_units=[128, 32]
)
classifier.fit(train_data, train_target, steps=5)
accuracy_score = classifier.evaluate(test_data, test_target, steps=5)['accuracy']
print("accuracy: ", 100*accuracy_score,"%")
Output:
WARNING:tensorflow:Skipping summary for global_step, must be a float or np.float32.
accuracy: 44.3899989128 %
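Working backwards toward the code in the question: the working example above feeds MNIST_DATASET.train.labels as the targets, while the question assigns y_train and y_test from the images arrays, which matches the 784 reported in the squeeze error. A sketch of how the question's data loading would look with integer labels (using read_data_sets without one_hot, since DNNClassifier expects class ids; this is an assumption about the intended fix, not the asker's code):
# Sketch only: load MNIST with integer class labels instead of one-hot vectors,
# and take the targets from .labels rather than .images.
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
X_train = np.array(mnist.train.images, 'float32')
X_test = np.array(mnist.test.images, 'float32')
y_train = np.array(mnist.train.labels, 'int64')   # shape (55000,), class ids 0-9
y_test = np.array(mnist.test.labels, 'int64')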