TensorFlow: dynamically adding files to a FIFOQueue during training

Date: 2016-04-17 12:20:38

Tags: tensorflow

I have tested code like this:



    # filename_queue comes from tf.train.string_input_producer
    features, labels, filename_queue = read_batch_data(file_list, 10)
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        counter = 0
        try:
            while not coord.should_stop():
                counter = counter + 1
                value = features.eval()
                if counter % 1000 == 0:
                    # check whether new data has been inserted into the queue
                    print counter, sum(value)
                    index = (counter / 1000) % 3
                    enqueue_op = filename_queue.enqueue(['a%d.csv' % index])
                    sess.run([enqueue_op])
        except tf.errors.OutOfRangeError:
            ...

But it seems that the graph still uses the original filename queue and never reads the new data.
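
For context, read_batch_data is not defined in the question; a hypothetical sketch consistent with how it is called above might look like this:

    # Hypothetical sketch of read_batch_data (not part of the original
    # question): build a filename queue, read CSV lines, and batch them.
    def read_batch_data(file_list, batch_size):
        filename_queue = tf.train.string_input_producer(file_list)
        reader = tf.TextLineReader()
        _, value = reader.read(filename_queue)
        # assume each row is "feature,label"
        feature, label = tf.decode_csv(value, record_defaults=[[0.0], [0]])
        features, labels = tf.train.batch([feature, label],
                                          batch_size=batch_size)
        return features, labels, filename_queue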

1 Answer:

Answer 0 (score: 2):

I suspect you have a large prefetch buffer full of old filenames, so when you add a new filename it only becomes visible after the prefetch buffer is exhausted. By default, string_input_producer cycles through the given set of names indefinitely, and it fills a prefetch buffer of size 32.
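
A minimal sketch of where that buffer comes from (assuming the TF 0.x queue-runner API; capacity=32 is the documented default of string_input_producer):

import tensorflow as tf

# string_input_producer returns a queue that a background QueueRunner
# keeps topped up with up to `capacity` filenames, cycling forever by
# default -- so hand-enqueued names sit behind everything already buffered.
filename_queue = tf.train.string_input_producer(
    ['a0.csv', 'a1.csv'], capacity=32)  # capacity=32 is the default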

If you want to modify the list of filenames, it is easier to avoid string_input_producer and fill the filename queue manually, as the example below does. Be careful not to run out of filenames and hang the main thread; you may want to set config.operation_timeout_in_ms=5000 for your session.

The example below prints entries from the file /temp/pipeline/0 one at a time (there are 10 entries in the file), and after that it prints entries from /temp/pipeline/1.

Create some test data:

import os

def dump_numbers_to_file(fname, start_num, end_num):
  with open(fname, 'w') as f:
    for i in range(start_num, end_num):
      f.write(str(i)+"\n")

num_files=10
num_entries_per_file=10
file_root="/temp/pipeline"
os.system('mkdir -p '+file_root)
for fi in range(num_files):
  fname = file_root+"/"+str(fi)
  dump_numbers_to_file(fname, fi*num_entries_per_file, (fi+1)*num_entries_per_file)

Helper utility to create a session:

import tensorflow as tf

def create_session():
  """Resets local session, returns new InteractiveSession"""
  config = tf.ConfigProto(log_device_placement=True)
  config.gpu_options.per_process_gpu_memory_fraction=0.3 # don't hog all vRAM
  config.operation_timeout_in_ms=15000   # terminate on long hangs
  sess = tf.InteractiveSession("", config=config)
  return sess

Run your example:

tf.reset_default_graph()
filename_queue = tf.FIFOQueue(capacity=10, dtypes=[tf.string])
enqueue_op = filename_queue.enqueue("/temp/pipeline/0")
sess = create_session()
sess.run(enqueue_op)
sess.run(enqueue_op)
# filename queue now has [/temp/pipeline/0, /temp/pipeline/0]
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
numeric_val, = tf.decode_csv(value, record_defaults=[[-1]])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

for i in range(10):
  print sess.run([numeric_val])

# filename queue now has [/temp/pipeline/0]
print 'size before', sess.run(filename_queue.size())
sess.run(filename_queue.enqueue("/temp/pipeline/1"))

# filename queue now has [/temp/pipeline/0, /temp/pipeline/1]
print 'size after', sess.run(filename_queue.size())

for i in range(10):
  print sess.run([numeric_val])

# filename queue now has [/temp/pipeline/1]

for i in range(10):
  print sess.run([numeric_val])

# filename queue is now empty, next sess.run([numeric_val]) would hang

coord.request_stop()
coord.join(threads)

You should see:

[0]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
size before 1
size after 2
[0]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
[10]
[11]
[12]
[13]
[14]
[15]
[16]
[17]
[18]
[19]

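
If you would rather get a clean error than a hang once the filename queue drains, one option (a sketch, reusing the session and tensors from the example above) is to close the queue so that a pending read raises OutOfRangeError:

# Close the filename queue: once it is empty and closed, the reader's
# dequeue raises OutOfRangeError instead of blocking until the
# operation timeout expires.
sess.run(filename_queue.close())
try:
  sess.run([numeric_val])
except tf.errors.OutOfRangeError:
  print 'filename queue exhausted'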