Tensorflow Shuffle Batch Non Deterministic

时间:2018-01-08 18:52:33

标签: tensorflow

我正试图从tf.train.shuffle_batch()获得确定性行为。相反,我可以使用工作正常的tf.train.batch()(总是与元素的顺序相同),但我需要从多个tf记录中获取示例,因此我被shuffle_batch()困住了。

我正在使用:

random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
data_entries = tf.train.shuffle_batch(
    [data], batch_size=batch_size, num_threads=1, capacity=512,
    seed=57, min_after_dequeue=32)

但每次我重新启动脚本时,我得到的结果略有不同(不完全不同,但大约20%的元素排序错误)。 有什么我想念的吗?

编辑:解决了!请参阅下面的答案!

2 个答案:

答案 0 :(得分:0)

也许我误解了某些内容,但您可以使用tf.train.string_input_producer()在队列中收集多个tf记录,然后将示例读入张量并最终使用tf.train.batch()

看看CIFAR-10 input

答案 1 :(得分:0)

回答我自己的问题:

首先,shuffle_batch的原因是非确定性的:

  • 我申请批次的时间本质上是随机的。
  • 在那个时候,可以使用随机数量的张量。
  • Tensorflow调用播种的随机播放操作,但根据项目数量,它将返回不同的顺序。

因此,无论种子如何,顺序总是不同,除非元素的数量是不变的。所以解决方案是保持元素数量不变,但我们如何做呢?

设置capacity=min_after_dequeue+batch_size。这将强制Tensorflow填充队列,直到它达到满容量,然后才能使项目出列。因此,在随机播放操作时,我们有capacity个许多项,这是一个常数。

那我们为什么这样做呢?因为一个tf.record包含许多示例,但我们需要来自多个tf.records的示例。对于正常的批处理,我们首先得到一个记录的所有示例,然后是下一个记录的所有示例。这也意味着我们应该将min_after_dequeue设置为大于一个tf.record中的项目数。在我的示例中,我在一个文件中有50个示例,因此我设置了min_after_dequeue=2048

或者,我们也可以在创建tf.records之前对这些示例进行随机播放,但这对我来说是不可能的,因为我从多个目录中读取了tf.records(每个目录都有自己的数据集)。

最后注意:您还应该使用批量大小1进行超级保存。