Tensorflow了解tf.train.shuffle_batch

时间:2016-06-25 20:59:01

标签: machine-learning tensorflow mathematical-optimization gradient-descent

我有一个训练数据文件,大约100K行,而且我在每个训练步骤中都运行一个简单的tf.train.GradientDescentOptimizer。该设置基本上直接来自Tensorflow的MNIST示例。代码转载如下:

x = tf.placeholder(tf.float32, [None, 21])
W = tf.Variable(tf.zeros([21, 2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 2])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

鉴于我正在从文件中读取训练数据,我正在使用tf.train.string_input_producertf.decode_csv来读取csv中的行,然后使用tf.train.shuffle_batch来创建我当时的批次火车上。

我对tf.train.shuffle_batch的参数应该感到困惑。我阅读了Tensorflow的文档,但我仍然不确定“最佳”batch_size,capacity和min_after_dequeue值是什么。任何人都可以帮助我解释如何为这些参数选择合适的值,或者将我链接到我可以了解更多信息的资源?谢谢 -

以下是API链接:https://www.tensorflow.org/versions/r0.9/api_docs/python/io_ops.html#shuffle_batch

1 个答案:

答案 0 :(得分:2)

有一些关于

使用的线程数

https://www.tensorflow.org/versions/r0.9/how_tos/reading_data/index.html#batching

不幸的是,我不认为批量大小有一个简单的答案。 网络的有效批量大小取决于许多细节 关于网络。在实践中,如果您关心最佳性能 你需要做一堆反复试验(也许是开始 从类似网络使用的值。)