关于使用tf.train.shuffle_batch()创建批次

时间:2016-09-02 02:48:28

标签: tensorflow

Tensorflow tutorial中,它提供了有关tf.train.shuffle_batch()的以下示例:

# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch(
     [single_image, single_label],
     batch_size=32,
     num_threads=4,
     capacity=50000,
     min_after_dequeue=10000)

我不清楚capacitymin_after_dequeue的含义。在此示例中,它分别设置为5000010000。这种设置的逻辑是什么,或者是什么意思。如果输入有200个图像和200个标签,会发生什么?

1 个答案:

答案 0 :(得分:24)

tf.train.shuffle_batch()函数在内部使用tf.RandomShuffleQueue来累积批量的batch_size元素,这些元素从当前队列中的元素中随机均匀采样。

许多训练算法,例如TensorFlow用于优化神经网络的基于随机梯度下降的算法,依赖于从整个训练集中随机均匀地采样记录。但是,将整个训练集加载到内存中(以便从中进行采样)并不总是切实可行,因此tf.train.shuffle_batch()提供了一种折衷:它填充了min_after_dequeue和{{1}之间的内部缓冲区元素,并从该缓冲区随机均匀地采样。对于许多培训过程,这可以提高模型的准确性并提供足够的随机化。

capacitymin_after_dequeue参数会对培训效果产生间接影响。设置较大的capacity值会延迟训练的开始,因为TensorFlow必须在训练开始之前处理至少那么多元素。 min_after_dequeue是输入管道将消耗的内存量的上限:将此设置得太大可能导致训练过程耗尽内存(并且可能开始交换,这将损害训练吞吐量)。

如果数据集只有200个图像,则可以轻松地将整个数据集加载到内存中。 capacity效率很低,因为它会在tf.train.shuffle_batch()中多次将每个图像排列并标记。在这种情况下,您可能会发现使用tf.train.slice_input_producer()tf.train.batch()执行以下操作会更有效:

tf.RandomShuffleQueue