我已经测试了这样的代码:
# filename_queue comes from tf.train.string_input_producer
features, labels, filename_queue = read_batch_data(file_list, 10)
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
counter = 0
try:
while not coord.should_stop():
counter = counter + 1
value = features.eval()
if counter % 1000 == 0:
# check whether new data has been inserted into the queue
print counter, sum(value)
index = (counter / 1000) % 3
enqueue_op = filename_queue.enqueue(['a%d.csv' % index])
sess.run([enqueue_op])
except tf.errors.OutOfRangeError
...
但是看起来图表仍然使用原始文件队列,并且从不读取新数据。
答案 0 :(得分:2)
我怀疑你有一个带有旧名称的大型预取缓冲区,所以当你添加一个新文件名时,它只会在预取缓冲区耗尽后才能看到。默认情况下,string_input_producer
将通过名称集无限循环,它将填充大小为 32
的预取缓冲区。
如果您想修改文件名列表,则更容易使用 FIFOQueue
并手动填入文件名,而不是 string_input_producer。
请注意不要因提供的样例不足而挂起主线程,可能想为您的会话设置 config.operation_timeout_in_ms=5000。
def dump_numbers_to_file(fname, start_num, end_num):
    """Write the integers start_num .. end_num-1 to fname, one per line.

    The file is created (or truncated) and closed via the context manager.
    """
    with open(fname, 'w') as f:
        for i in range(start_num, end_num):
            f.write(str(i) + "\n")
# Create some test data: num_files files under file_root, each holding
# num_entries_per_file consecutive integers (one per line), so file fi
# contains fi*10 .. fi*10+9.
num_files = 10
num_entries_per_file = 10
file_root = "/temp/pipeline"
os.system('mkdir -p ' + file_root)  # ensure the directory exists
for fi in range(num_files):
    fname = file_root + "/" + str(fi)
    dump_numbers_to_file(fname, fi*num_entries_per_file, (fi+1)*num_entries_per_file)
例如,一次打印一个条目时,你会先看到来自 /temp/pipeline/0 文件的条目(文件中有 10 个条目),之后将打印来自 /temp/pipeline/1 的条目
创建一些测试数据
def create_session():
    """Reset-friendly helper: build and return a new InteractiveSession.

    The session logs device placement, caps this process at 30% of GPU
    memory, and fails (rather than blocks forever) on ops that hang.
    """
    cfg = tf.ConfigProto(log_device_placement=True)
    # Leave vRAM for other processes instead of grabbing it all.
    cfg.gpu_options.per_process_gpu_memory_fraction = 0.3
    # A dequeue on an empty queue would hang; time out after 15s instead.
    cfg.operation_timeout_in_ms = 15000
    return tf.InteractiveSession("", config=cfg)
用于创建会话的辅助实用程序
tf.reset_default_graph()
filename_queue = tf.FIFOQueue(capacity=10, dtypes=[tf.string])
enqueue_op = filename_queue.enqueue("/temp/pipeline/0")
sess = create_session()
sess.run(enqueue_op)
sess.run(enqueue_op)
# filename queue now has [/temp/pipeline/0, /temp/pipeline/0]
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
numeric_val, = tf.decode_csv(value, record_defaults=[[-1]])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
print sess.run([numeric_val])
# filename queue now has [/temp/pipeline/0]
print 'size before', sess.run(filename_queue.size())
sess.run(filename_queue.enqueue("/temp/pipeline/1"))
# filename queue now has [/temp/pipeline/0, /temp/pipeline/1]
print 'size after', sess.run(filename_queue.size())
for i in range(10):
print sess.run([numeric_val])
# filename queue now has [/temp/pipeline/1]
for i in range(10):
print sess.run([numeric_val])
# filename queue is now empty, next sess.run([numeric_val]) would hang
coord.request_stop()
coord.join(threads)
运行您的示例
[0]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
size before 1
size after 2
[0]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
[10]
[11]
[12]
[13]
[14]
[15]
[16]
[17]
[18]
[19]
你应该看到上面列出的输出。