What's wrong with this TensorFlow code? I seem to be missing something. It doesn't converge: the loss gets stuck at 2.30.
```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W1 = tf.Variable(tf.zeros([784, 100]))
b1 = tf.Variable(tf.zeros([100]))
W2 = tf.Variable(tf.zeros([100, 20]))
b2 = tf.Variable(tf.zeros([20]))
W3 = tf.Variable(tf.zeros([20, 10]))
b3 = tf.Variable(tf.zeros([10]))
y1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
y2 = tf.nn.relu(tf.add(tf.matmul(y1, W2), b2))
y3 = tf.nn.softmax(tf.add(tf.matmul(y2, W3), b3))
sess = tf.InteractiveSession()
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y3), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()
init = tf.global_variables_initializer()
sess.run(init)
for _ in range(10000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    print(sess.run(cross_entropy, feed_dict={x: batch_xs, y_: batch_ys}))
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```
Thanks!
Answer 0 (score: 0)
I can see a few issues that should be fixed:
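A network initialized with zeros (`tf.zeros`) cannot learn, for two reasons:

1. Symmetry: every unit in a layer starts with identical weights, so all of them compute the same output and receive the same gradient. They never become different from each other, so the layer effectively behaves like a single unit.
2. Zero gradients: with all weights and biases at zero, every ReLU input is 0, so the hidden activations are 0, and because `W3` is also zero no gradient reaches `W1`, `b1`, `W2`, `b2` or `W3`. Only `b3` is ever updated, so the best the network can do is predict the class frequencies, which for ten roughly balanced classes gives a loss of about ln(10) ≈ 2.30. That is exactly where your training gets stuck.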
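To see the zero-initialization problem concretely, here is a small NumPy sketch (not TensorFlow) that rebuilds the question's 784-100-20-10 ReLU network by hand, runs one forward and backward pass, and reports which parameters receive a nonzero gradient. The random scheme (normal weights with standard deviation 0.1, biases of 0.1), the fake input data, and the helper names `init_params` and `loss_and_grads` are assumptions made up for this illustration.

```python
import numpy as np

# Layer shapes from the question's network: 784 -> 100 -> 20 -> 10.
SHAPES = [(784, 100), (100,), (100, 20), (20,), (20, 10), (10,)]
NAMES = ["W1", "b1", "W2", "b2", "W3", "b3"]


def init_params(rng, mode):
    """'zeros' mimics tf.zeros; 'random' is an assumed small random init:
    N(0, 0.1) weights and biases of 0.1 (hypothetical choice)."""
    if mode == "zeros":
        return [np.zeros(s) for s in SHAPES]
    return [rng.normal(0.0, 0.1, size=s) if len(s) == 2 else np.full(s, 0.1)
            for s in SHAPES]


def loss_and_grads(params, x, y):
    """Forward and backward pass: ReLU, ReLU, softmax, mean cross-entropy."""
    W1, b1, W2, b2, W3, b3 = params
    z1 = x @ W1 + b1
    a1 = np.maximum(z1, 0.0)
    z2 = a1 @ W2 + b2
    a2 = np.maximum(z2, 0.0)
    logits = a2 @ W3 + b3
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    loss = -np.mean(np.sum(y * np.log(probs), axis=1))

    d_logits = (probs - y) / x.shape[0]   # d(mean loss) / d(logits)
    gW3, gb3 = a2.T @ d_logits, d_logits.sum(axis=0)
    d_z2 = (d_logits @ W3.T) * (z2 > 0)   # ReLU gradient is 0 where z <= 0
    gW2, gb2 = a1.T @ d_z2, d_z2.sum(axis=0)
    d_z1 = (d_z2 @ W2.T) * (z1 > 0)
    gW1, gb1 = x.T @ d_z1, d_z1.sum(axis=0)
    return loss, dict(zip(NAMES, [gW1, gb1, gW2, gb2, gW3, gb3]))


rng = np.random.default_rng(0)
x = rng.random((64, 784))                      # fake pixel values in [0, 1)
y = np.eye(10)[rng.integers(0, 10, size=64)]   # fake one-hot labels

results = {}
for mode in ("zeros", "random"):
    loss, grads = loss_and_grads(init_params(rng, mode), x, y)
    nonzero = [name for name, g in grads.items() if np.any(g != 0)]
    results[mode] = (loss, nonzero)
    print(f"{mode}: loss={loss:.4f}, nonzero gradients: {nonzero}")
```

The zeros line prints `zeros: loss=2.3026, nonzero gradients: ['b3']`, i.e. ln(10) with only the output bias able to move, while the random initialization should give every parameter a nonzero gradient. In the TensorFlow code, the corresponding change would be to create the weight matrices with a random initializer such as `tf.truncated_normal([784, 100], stddev=0.1)` instead of `tf.zeros(...)`.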
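I'd also recommend using the built-in `tf.losses.softmax_cross_entropy` instead of rolling your own. That's generally a good idea because it minimizes the chance of making a mistake along the way. :) Note that it expects the raw logits, so you would pass `tf.add(tf.matmul(y2, W3), b3)` rather than the softmax output `y3`.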
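One concrete way the hand-rolled loss can go wrong: it applies softmax first and then takes the log, so once a probability underflows to 0 the log becomes -inf, and multiplying that by a 0 label gives NaN. A loss computed directly from logits can use the log-sum-exp form instead. The NumPy sketch below contrasts the two formulations; it only illustrates the idea and is not TensorFlow's actual implementation, and both function names are hypothetical.

```python
import numpy as np


def xent_softmax_then_log(logits, y):
    """Hand-rolled style, like -sum(y_ * tf.log(tf.nn.softmax(logits)))."""
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        # log(0) = -inf, and 0 * -inf = nan for the non-target classes
        return -np.mean(np.sum(y * np.log(probs), axis=1))


def xent_from_logits(logits, y):
    """Log-softmax via log-sum-exp, computed directly from the logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.sum(y * log_probs, axis=1))


y = np.array([[0.0, 1.0, 0.0]])                   # true class is 1
confident_wrong = np.array([[1000.0, 0.0, 0.0]])  # very sure it is class 0
mild = np.array([[0.5, 1.5, -0.3]])

print(xent_softmax_then_log(confident_wrong, y))  # nan
print(xent_from_logits(confident_wrong, y))       # 1000.0
print(np.isclose(xent_softmax_then_log(mild, y), xent_from_logits(mild, y)))  # True
```

For well-behaved logits the two agree; they differ only when the softmax saturates, which is exactly when a NaN loss would silently ruin training.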