TensorFlow MNIST neural network model has low accuracy

Posted: 2019-06-21 18:10:36

Tags: python python-3.x tensorflow

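This basic MNIST model only reaches 9.8% accuracy, which is no better than guessing among the 10 digits.

Can someone help me understand what I'm doing wrong here?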

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions)

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.reshape(x_train, (60000, 784))
x_test  = np.reshape(x_test, (10000, 784))
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
w = tf.Variable(tf.zeros((784, 10)))
b = tf.Variable(tf.zeros((10)))
y_hat = tf.nn.softmax(tf.matmul(x, w) + b)

cross_entropy = tf.reduce_mean(-tf.reduce_sum((y * tf.log(y_hat)), axis=1))
training_gd = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.initializers.global_variables().run()

for _ in range(10000):
    indices = np.random.randint(0, len(x_train), 100)
    batch_xs, batch_ys =  x_train[indices], y_train[indices]
    sess.run(training_gd, feed_dict={x: batch_xs, y: batch_ys})

correct_prediction = tf.equal(tf.argmax(y_hat, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: x_test, y: y_test}))
Output: 0.098

1 answer:

Answer 0 (score: 1)

Computing the cross-entropy with

cross_entropy = tf.reduce_mean(-tf.reduce_sum((y * tf.log(y_hat)), axis=1))

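is numerically unstable. With raw pixel values between 0 and 255, the logits can grow very large after a few updates, the softmax saturates, and some entries of y_hat become exactly 0. Since tf.log(0) is -inf, the loss and its gradients become NaN (or infinite), and the weights stop learning anything useful.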
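A minimal NumPy sketch of this failure mode (no TensorFlow involved; the logit values are made up for illustration, and `naive_cross_entropy` is a hypothetical helper mirroring the question's formula). The softmax below subtracts the row maximum, so it never overflows, yet entries far below the maximum still underflow to exactly 0.0:

```python
import numpy as np

def softmax(z):
    # Subtract the row max before exponentiating so exp() never overflows.
    # Entries far below the max still underflow to exactly 0.0.
    shifted = z - z.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def naive_cross_entropy(y, y_hat):
    # Same formula as the question: mean over rows of -sum(y * log(y_hat)).
    # errstate only silences the warnings; it does not fix the result.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.mean(-np.sum(y * np.log(y_hat), axis=1))

y = np.array([[0.0, 1.0, 0.0]])                   # one-hot label: class 1
small_logits = np.array([[1.0, 2.0, 0.5]])        # hypothetical moderate logits
large_logits = np.array([[900.0, 50.0, -300.0]])  # hypothetical saturated logits

print(naive_cross_entropy(y, softmax(small_logits)))  # a finite value
print(naive_cross_entropy(y, softmax(large_logits)))  # nan
```

With the saturated logits, `log(0.0)` is `-inf` and `0.0 * -inf` is `nan`, so a single saturated example is enough to make the whole batch loss NaN.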

You can normalize the images to the range [0, 1]:

x_train = np.reshape(x_train, (60000, 784)) / 255.0
x_test  = np.reshape(x_test, (10000, 784)) / 255.0
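A small NumPy sketch of why the scaling helps, assuming random uint8 arrays as a stand-in for the MNIST images and small random weights with the same shape as the question's `w` (both hypothetical, not the real data or trained weights):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for MNIST: 100 flattened 28x28 images with values 0..255
x_raw = rng.integers(0, 256, size=(100, 784)).astype(np.uint8)
# Hypothetical weights, same shape as w in the question
w = rng.normal(scale=0.01, size=(784, 10))

x_norm = x_raw / 255.0  # same scaling as in the answer (gives float64)

logits_raw = x_raw.astype(np.float64) @ w
logits_norm = x_norm @ w

print(x_norm.min(), x_norm.max())  # both within [0, 1]
print(np.abs(logits_raw).max(), np.abs(logits_norm).max())
```

Because the logits are linear in `x`, dividing the inputs by 255 divides the logits by 255 for the same weights. The gradient of the loss with respect to `w` is proportional to `x` as well, so each update with `learning_rate=0.05` is correspondingly smaller and the softmax is far less likely to saturate.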

Alternatively, compute the cross-entropy with the TensorFlow API, which works directly on the logits:

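logits = tf.matmul(x, w) + b
y_hat = tf.nn.softmax(logits)  # still needed for the accuracy computation

# Returns one loss per example; average them to get a scalar loss as before
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))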
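To illustrate why working from the logits is safe, here is a NumPy sketch of the mathematically equivalent log-sum-exp form. `cross_entropy_from_logits` is a hypothetical helper, not TensorFlow's actual kernel; it only computes the same quantity in a stable way, reusing the made-up saturated logits from the failure case:

```python
import numpy as np

def cross_entropy_from_logits(logits, labels):
    # log softmax(z)_k = z_k - m - log(sum_j exp(z_j - m)), with m = max_j z_j.
    # The log is applied to a sum that is always >= 1, so it never sees 0.
    m = logits.max(axis=1, keepdims=True)
    log_softmax = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
    return -np.sum(labels * log_softmax, axis=1)  # one loss per example

y = np.array([[0.0, 1.0, 0.0]])                   # one-hot label: class 1
large_logits = np.array([[900.0, 50.0, -300.0]])  # hypothetical saturated logits

loss = cross_entropy_from_logits(large_logits, y)
print(loss)         # [850.]
print(loss.mean())  # 850.0, the batch mean
```

The loss is large but finite, so the gradients stay finite too. The per-example losses are then averaged over the batch, just as the question's original loss does with `tf.reduce_mean`.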