我的数据集中的每个类都有一个序列化文件。我想使用队列来加载这些文件中的每一个,然后将它们放在一个RandomShuffleQueue中,它将把它们拉下来,这样我就会从每个类中随机混合一些例子。我认为这段代码可行。
在此示例中,每个文件都有10个示例。
filenames = ["a", "b", ...]
with self.test_session() as sess:
# for each file open a queue and get that
# queue's results.
strings = []
rq = tf.RandomShuffleQueue(1000, 10, [tf.string], shapes=())
for filename in filenames:
q = tf.FIFOQueue(99, [tf.string], shapes=())
q.enqueue([filename]).run()
q.close().run()
# read_string just pulls a string from the file
key, out_string = input_data.read_string(q, IMAGE_SIZE, CHANNELS, LABEL_BYTES)
strings.append(out_string)
rq.enqueue([out_string]).run()
rq.close().run()
qs = rq.dequeue()
label, image = input_data.string_to_data(qs, IMAGE_SIZE, CHANNELS, LABEL_BYTES)
for i in range(11):
l, im = sess.run([label, image])
print("L: {}".format(l)
这适用于10次调用,但是在11日它说队列是空的。
我认为这是由于我对这些队列运作的误解。我向RandomShuffleQueue
添加了10个变量,但每个变量本身都是从队列中提取的,所以我假设在每个文件队列都为空之前,队列不会被清空。
我在这里做错了什么?
答案 0 :(得分:3)
此问题的正确答案取决于您拥有的文件数量,文件大小以及文件大小的分布情况。
您的示例的直接问题是rq
只为每个filename in filenames
获取一个元素,然后关闭队列。我假设有10 filenames
,因为rq.dequeue()
每次调用rq
时都会使用sess.run([label, image])
的一个元素。rq.dequeue()
由于队列已关闭,因此无法添加更多元素,并且rq.enqueue([out_string])
操作的第11次激活失败。
一般的解决方案是您必须创建其他线程以在循环中继续运行QueueRunner
。 TensorFlow包含一个用于简化此操作的N
类,以及一些处理常见情况的函数。 documentation for threading and queues解释了它们的使用方式,using queues to read from files也有一些很好的信息。
至于您的特定问题,您可以采用的一种方法是创建N
个读者(每个N
个文件)。然后,您可以将tf.pack()
min_after_dequeue
个元素(每个读者一个)放入批处理中,并使用enqueue_many
一次将批处理添加到容量足够大的tf.RandomShuffleQueue
中和RandomShuffleQueue
以确保课程之间充分混合。在k
上调用dequeue_many(k)
可以为您提供从每个文件中采样的一批sumOfA
元素,但概率相等。