使用队列从多个输入文件中统一采样

时间:2015-11-15 02:25:31

标签: tensorflow

我的数据集中的每个类都有一个序列化文件。我想使用队列来加载这些文件中的每一个,然后将它们放在一个RandomShuffleQueue中,它将把它们拉下来,这样我就会从每个类中随机混合一些例子。我认为这段代码可行。

在此示例中,每个文件都有10个示例。

filenames = ["a", "b", ...]

with self.test_session() as sess:
  # for each file open a queue and get that
  # queue's results. 
  strings = []
  rq = tf.RandomShuffleQueue(1000, 10, [tf.string], shapes=())
  for filename in filenames:
    q = tf.FIFOQueue(99, [tf.string], shapes=())
    q.enqueue([filename]).run()
    q.close().run()
    # read_string just pulls a string from the file
    key, out_string = input_data.read_string(q, IMAGE_SIZE, CHANNELS, LABEL_BYTES)
    strings.append(out_string)

    rq.enqueue([out_string]).run()

  rq.close().run()
  qs = rq.dequeue()
  label, image = input_data.string_to_data(qs, IMAGE_SIZE, CHANNELS, LABEL_BYTES)
  for i in range(11):
    l, im = sess.run([label, image])
    print("L: {}".format(l)

这适用于10次调用,但是在11日它说队列是空的。

我认为这是由于我对这些队列运作的误解。我向RandomShuffleQueue添加了10个变量,但每个变量本身都是从队列中提取的,所以我假设在每个文件队列都为空之前,队列不会被清空。

我在这里做错了什么?

1 个答案:

答案 0 :(得分:3)

此问题的正确答案取决于您拥有的文件数量,文件大小以及文件大小的分布情况。

您的示例的直接问题是rq只为每个filename in filenames获取一个元素,然后关闭队列。我假设有10 filenames,因为rq.dequeue()每次调用rq时都会使用sess.run([label, image])的一个元素。rq.dequeue()由于队列已关闭,因此无法添加更多元素,并且rq.enqueue([out_string])操作的第11次激活失败。

一般的解决方案是您必须创建其他线程以在循环中继续运行QueueRunner。 TensorFlow包含一个用于简化此操作的N类,以及一些处理常见情况的函数。 documentation for threading and queues解释了它们的使用方式,using queues to read from files也有一些很好的信息。

至于您的特定问题,您可以采用的一种方法是创建N个读者(每个N个文件)。然后,您可以将tf.pack() min_after_dequeue个元素(每个读者一个)放入批处理中,并使用enqueue_many一次将批处理添加到容量足够大的tf.RandomShuffleQueue中和RandomShuffleQueue以确保课程之间充分混合。在k上调用dequeue_many(k)可以为您提供从每个文件中采样的一批sumOfA元素,但概率相等。