我刚开始使用 TensorFlow,正在尝试读取 MNIST 手写数字数据集。我的代码中有一个错误,但我不明白原因。我找到了一个与此类似的帖子,但我的这段代码仍然出现同样的错误。(该帖子的链接:TensorFlow Cannot feed value of shape (100, 784) for Tensor 'Placeholder:0')
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data set with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Input placeholder: every array fed to X must be shaped [batch, 28, 28, 1].
X = tf.placeholder(tf.float32, [None, 28, 28, 1])
W = tf.Variable(tf.zeros([784, 10]))
B = tf.Variable(tf.zeros([10]))
init = tf.global_variables_initializer()

# Model: flatten each image to 784 values, apply one linear layer, then softmax.
Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, 784]), W) + B)

# Placeholder for the correct answers (one row of 10 values per image).
Y_ = tf.placeholder(tf.float32, [None, 10])

# Loss: cross-entropy summed over the batch.
cross_entropy = -tf.reduce_sum(Y_ * tf.log(Y))

# Fraction of correct predictions.
is_correct = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Gradient-descent training step.
optimizer = tf.train.GradientDescentOptimizer(0.003)
train_step = optimizer.minimize(cross_entropy)

# Training process.
sess = tf.Session()
sess.run(init)
for i in range(1000):
    # Load a batch of images and their labels.
    batch_X, batch_Y = mnist.train.next_batch(100)
    # Reshape the batch to the 4-D layout that placeholder X expects.
    batch_X = np.reshape(batch_X, (-1, 28, 28, 1))
    train_data = {X: batch_X, Y_: batch_Y}
    # Train.
    sess.run(train_step, feed_dict=train_data)
    # Success on the training batch?
    a, c = sess.run([accuracy, cross_entropy], feed_dict=train_data)

# Success on the test data?
# The test images need the same reshape as the training batches; feeding them
# unreshaped does not match X and raises "Cannot feed value of shape ...".
test_images = np.reshape(mnist.test.images, (-1, 28, 28, 1))
test_data = {X: test_images, Y_: mnist.test.labels}
# Session.run takes the feed as `feed_dict`; `feed` is not a valid keyword.
a, c = sess.run([accuracy, cross_entropy], feed_dict=test_data)
答案 0 :(得分:1)
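将最后两行(构造 test_data 并调用 sess.run 的部分)更改为如下代码。占位符 X 的形状是 [None, 28, 28, 1],因此测试图像也必须像训练批次一样先 reshape;另外 sess.run 的关键字参数应为 feed_dict: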
# Reshape the test images to the 4-D layout that placeholder X expects.
test_images = np.reshape(mnist.test.images, (-1, 28, 28, 1))
# Feed the reshaped images to X and the test labels to Y_.
test_data = {X: test_images, Y_: mnist.test.labels}
a, c = sess.run([accuracy, cross_entropy], feed_dict=test_data)