Tensorflow批处理参数

时间:2016-09-14 13:46:40

标签: machine-learning tensorflow

批处理看起来比feed_dicts更清晰,所以我试图理解Tensorflow中的批处理。

以下代码块是否会在批处理中创建32个相同的图像,然后将其提供给队列?

# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch(
      [single_image, single_label],
      batch_size=32,
      num_threads=4,
      capacity=50000,
      min_after_dequeue=10000)

一些上下文:我目前有一个包含大约50K行的文件。我正在使用tf.train.string_input_producertf.decode_csv来读取csv中的行,但是对于作为参数提供给tf.train.shuffle_batch的内容,包含所有行的单个行或张量感到困惑从文件中读取。

1 个答案:

答案 0 :(得分:1)

在您问题的代码段中,张量single_imagesingle_label对应于一个图片及其相关标签。从tf.train.shuffle_batch()image_batchlabel_batch返回的张量对应于32个可能包含在一起的* - 不同图像,以及32个相关标签。 TensorFlow在内部使用tf.RandomShuffleQueue来重新排列数据,并创建其他线程来评估single_imagesingle_label,以便将它们添加到此队列中。

tf.train.shuffle_batch()函数根据您传递的参数具有不同的行为。例如,如果传递enqueue_many=True,则tensors参数中的张量将被解释为元素批量,TensorFlow将在前导维度上将它们连接起来(因此每个张量必须具有相同的大小)第0维)。使用enqueue_many=True,您可以将整个数据集传递给tf.train.shuffle_batch(),也可以传递批量元素(例如使用tf.ReaderBase.read_up_to())。

*我说“可能不同”,因为您正在使用批处理函数的洗牌版本,如果您的数据集与capacitymin_after_dequeue参数相比较小,则可能是您将在一个批次中看到同一示例的多个副本。