Tensorflow:会话之间正确的队列关闭

时间:2016-05-27 09:45:13

标签: python tensorflow

我有两个图表(一个用于培训,一个用于评估),它们使用相同的推理网络,但输入数据不同。输入是通过读取二进制数据文件创建的。这些图表是一个接一个地运行并在单独的会话中运行。它实际上似乎有效,但我无法摆脱每次关闭会话时发生的警告:

W tensorflow/core/common_runtime/executor.cc:1102] 0x7fb7ac082980 Compute status: Cancelled: Enqueue operation was cancelled
 [[Node: input_producer/input_producer_EnqueueMany = QueueEnqueueMany[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer, input_producer/RandomShuffle)]]
I tensorflow/core/kernels/queue_base.cc:286] Skipping cancelled enqueue attempt

我的问题是,如何正确关闭所有队列和线程以避免警告。我写了一个小例子,它展示了我正在做的事情,希望能让你重现我的问题。

import tensorflow as tf
import struct
import numpy as np


file_name = 'test.bin'

def write_binary_file():
    with open(file_name, 'w') as f:
        for i in range(20):
            # image = np.ones([3, 3])*i
            image = np.zeros([3, 3])
            image[:, 0] = i
            image = image.astype('uint8')
            for u in range(3):
                for v in range(3):
                    f.write('%s' % struct.pack('B', image[u, v]))

# write binary file
write_binary_file()
print "size of one entry"
print 9
print "size of file"
print 20*9

# GRAPH DEFINITION
filename_queue = tf.train.string_input_producer([file_name])

reader = tf.FixedLengthRecordReader(header_bytes=0, record_bytes=9)
key, value = reader.read(filename_queue)

value_uint8 = tf.reshape(tf.decode_raw(value, tf.uint8), [3, 3])

# FIRST session
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(5):
    key_v, value_uint8_v = sess.run([key, value_uint8])
    print key_v
    print value_uint8_v

coord.request_stop() # *HERE* the warning is thrown
coord.join(threads, stop_grace_period_secs=5)
sess.close()

# SECOND session
sess2 = tf.Session()
coord2 = tf.train.Coordinator()
threads2 = tf.train.start_queue_runners(sess=sess2, coord=coord2)
for i in range(5):
    key_v, value_uint8_v = sess2.run([key, value_uint8])
    print key_v
    print value_uint8_v

coord2.request_stop()
coord2.join(threads2, stop_grace_period_secs=5)
sess2.close()

您知道如何修复警告,还是知道更改数据源的更好方法?

1 个答案:

答案 0 :(得分:0)

如雅罗斯拉夫所说;我相信这个警告在TF的最新版本中已得到解决。