Does TensorFlow perform poorly on simple networks?

Date: 2016-04-21 20:45:15

Tags: python performance machine-learning tensorflow

I've been experimenting with simple, basic (introductory-tutorial-level) neural networks in various frameworks, and I'm puzzled by the performance I'm seeing in TensorFlow.

For example, the simple network from Michael Nielsen's tutorial (MNIST digit recognition using L2-regularized stochastic gradient descent in a network with 30 hidden nodes) performs much worse in TensorFlow (each epoch takes roughly 8x as long, with all the same parameters) than a slightly adapted version of Nielsen's basic NumPy code (adapted to use the mini-batch vectorization suggested in one of the tutorial exercises).

Is this typical of TensorFlow running on a single CPU? Are there settings I should be tweaking to improve performance? Or does TensorFlow only really shine with more complex networks or training regimes, so that it shouldn't be expected to do well on simple toy cases like this?

from __future__ import (absolute_import, print_function, division, unicode_literals)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time


def weight_variable(shape):
    # Small random initial weights (truncated normal, stddev 0.1).
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def bias_variable(shape):
    # Small constant positive initial biases.
    return tf.Variable(tf.constant(0.1, shape=shape))


mnist = input_data.read_data_sets("./data/", one_hot=True)  # downloads MNIST on first run

sess = tf.Session()

# Inputs and outputs
x = tf.placeholder(tf.float32, shape=[None, 784])   # flattened 28x28 images
y_ = tf.placeholder(tf.float32, shape=[None, 10])   # one-hot labels

# Model parameters: 784 -> 30 sigmoid hidden layer -> 10 softmax output
W1 = weight_variable([784, 30])
b1 = bias_variable([30])
o1 = tf.nn.sigmoid(tf.matmul(x, W1) + b1, name='o1')
W2 = weight_variable([30, 10])
b2 = bias_variable([10])
y = tf.nn.softmax(tf.matmul(o1, W2) + b2, name='y')

sess.run(tf.initialize_all_variables())  # TF <= 0.11 name; later tf.global_variables_initializer()

# Cross-entropy loss plus L2 weight regularization (lambda/n = 0.1/1000).
# Note: tf.log(y) can produce NaN if y underflows to 0; the numerically stable
# alternative is tf.nn.softmax_cross_entropy_with_logits on the pre-softmax logits.
loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
loss += 0.1/1000 * (tf.nn.l2_loss(W1) + tf.nn.l2_loss(W2))

# Plain SGD with learning rate 0.15.
train_step = tf.train.GradientDescentOptimizer(0.15).minimize(loss)

# Fraction of predictions that match the one-hot labels.
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))


# Train for 30 epochs with mini-batches of 40, timing each epoch and
# reporting test accuracy.
for ep in range(30):
    start = time.time()
    for mb in range(int(len(mnist.train.images) / 40)):
        batch_xs, batch_ys = mnist.train.next_batch(40)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    print("epoch %d: %.1f s, test accuracy %.4f" % (
        ep, time.time() - start,
        sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})))

1 Answer:

Answer 0 (score: 1):

Yes, I would expect hand-coded, specialized simple networks running on a CPU to run faster than their TensorFlow equivalents. The reason is usually the graph evaluation system TensorFlow uses: every sess.run call pays a fixed dispatch cost, and for a network this small (784x30x10 with batches of 40) that fixed cost can rival the actual arithmetic.
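A rough way to make that overhead visible (a minimal sketch of mine, not code from the answer; the names trivial and n_runs are hypothetical) is to time sess.run on an op that does essentially no work, so whatever time remains is the per-call dispatch cost:

import time
import tensorflow as tf

sess = tf.Session()
trivial = tf.constant(1.0)  # virtually zero computation

n_runs = 1000
start = time.time()
for _ in range(n_runs):
    sess.run(trivial)  # each call still pays graph dispatch + Python/C++ crossing
elapsed = time.time() - start
print("avg overhead per sess.run: %.3f ms" % (elapsed / n_runs * 1000))

If that per-call figure is comparable to the time the NumPy code spends on one mini-batch, an 8x slowdown per epoch is unsurprising.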

The benefit of using TensorFlow comes when you have much more complex algorithms: you can first test for correctness, and then easily port the same code to more machines and more processing units.
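For instance (a sketch under my own assumptions, not code from the question), moving a layer onto another device is a one-line change to the graph definition, whereas the NumPy version would need a rewrite:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784])
# Pin the layer to a device; in a distributed setup the string could name
# a remote worker instead, e.g. '/job:worker/task:0'.
with tf.device('/gpu:0'):
    W1 = tf.Variable(tf.truncated_normal([784, 30], stddev=0.1))
    b1 = tf.Variable(tf.constant(0.1, shape=[30]))
    o1 = tf.nn.sigmoid(tf.matmul(x, W1) + b1)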

For example, one thing you can try is to run your code on a machine that has a GPU: without changing anything in your code you should see a speedup, possibly beyond the hand-coded example you linked. Porting the hand-written code to a GPU, by contrast, would require considerable effort.
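To confirm the ops actually landed on the GPU (a sketch, assuming the TF 1.x Session API used in the question), you can ask the session to log device placement when it is created:

import tensorflow as tf

# Prints the device (CPU or GPU) each op was assigned to when the graph runs.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))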