简单网络出错

时间:2017-12-07 22:54:11

标签: tensorflow

这个tensorflow代码有什么问题?我似乎迟到了看错。它没有收敛。它停止了2.30。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


x = tf.placeholder(tf.float32, [None, 784])
W1 = tf.Variable(tf.zeros([784, 100]))
b1 = tf.Variable(tf.zeros([100]))

W2 = tf.Variable(tf.zeros([100, 20]))
b2 = tf.Variable(tf.zeros([20]))

W3 = tf.Variable(tf.zeros([20, 10]))
b3 = tf.Variable(tf.zeros([10]))

y1 = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
y2 = tf.nn.relu(tf.add(tf.matmul(y1, W2), b2))
y3 = tf.nn.softmax(tf.add(tf.matmul(y2, W3), b3))
sess = tf.InteractiveSession()

y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y3), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()
init = tf.global_variables_initializer()
sess.run(init)

for _ in range(10000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    print(sess.run(cross_entropy, feed_dict={x: batch_xs, y_: batch_ys}))
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

谢谢!

1 个答案:

答案 0 :(得分:0)

我可以看到一些应该解决的问题:

  • 对于随机梯度下降,0.5的学习率非常大。如果网络没有训练,您可以尝试使用不同的值,通常在[1e-2,1e-5]范围内。
  • 使用零(tf.zeros)初始化的网络无法学习,原因有两个:

    1. 参数值之间没有任何差异,渐变在它们之间均匀分配,这意味着它们都学会了相同的值。
    2. 由于在反向传播期间梯度乘以权重,因此结果值将始终等于零 - 意味着权重值没有变化。

我还建议使用内置的tf.losses.softmax_cross_entropy而不是自己动手。这通常是一个好主意,因为它最大限度地减少了在整个过程中犯错误的可能性。 :)