I have a small TFRecords file with only 640 records. The code below hangs, and I can't tell what's wrong with it:
import tensorflow as tf  # TensorFlow 1.x queue-based input API

# x_height, x_width and num_channels are defined elsewhere in the script.
def read_from_tfrecord(tfrecord_file):
    # Queue of filenames, filled by a QueueRunner thread
    tfrecord_file_queue = tf.train.string_input_producer(tfrecord_file, name='queue')
    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)
    tfrecord_features = tf.parse_single_example(tfrecord_serialized,
        features={'label': tf.FixedLenFeature([], tf.string),
                  'snippet': tf.FixedLenFeature([], tf.string)}, name='features')
    # Raw bytes -> float32 tensor of shape [x_height, x_width, num_channels]
    snippet = tf.decode_raw(tfrecord_features['snippet'], tf.float32)
    snippet = tf.reshape(snippet, [x_height, x_width, num_channels])
    # Raw bytes -> int32 tensor of shape [2]
    label = tf.decode_raw(tfrecord_features['label'], tf.int32)
    label = tf.reshape(label, [2])
    # Shuffling batch queue, also filled by a QueueRunner thread
    snippets_shuffled, labels_shuffled = tf.train.shuffle_batch([snippet, label],
                                                                batch_size=2,
                                                                capacity=10,
                                                                num_threads=1,
                                                                min_after_dequeue=4)
    return snippets_shuffled, labels_shuffled
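For the `decode_raw`/`reshape` pair to work, each serialized `snippet` must contain exactly `x_height * x_width * num_channels` float32 values (4 bytes each), and each `label` exactly two int32 values. The following is a minimal pure-Python sketch of that byte-level contract using the standard `array` module rather than TensorFlow itself; the helper names and the shape `(4, 3, 2)` are hypothetical, and it assumes a little-endian machine, matching `decode_raw`'s default `little_endian=True`.

```python
from array import array
from math import prod

# Hypothetical shape; the real x_height, x_width, num_channels are not given.
X_HEIGHT, X_WIDTH, NUM_CHANNELS = 4, 3, 2


def decode_raw(raw, typecode):
    """Reinterpret raw bytes as a flat array, like tf.decode_raw ('f' = float32, 'i' = int32)."""
    flat = array(typecode)
    if len(raw) % flat.itemsize:
        raise ValueError(f"{len(raw)} bytes is not a whole number of '{typecode}' items")
    flat.frombytes(raw)  # native byte order; little-endian on typical x86/ARM machines
    return flat


def reshape(flat, shape):
    """Row-major reshape into nested lists, failing like tf.reshape on a size mismatch."""
    if len(flat) != prod(shape):
        raise ValueError(f"cannot reshape {len(flat)} values into shape {shape}")
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


# Writer side: the byte layout the parsing code above expects.
snippet_bytes = array('f', [i * 0.5 for i in range(X_HEIGHT * X_WIDTH * NUM_CHANNELS)]).tobytes()
label_bytes = array('i', [0, 1]).tobytes()  # 'i' is 4 bytes on common platforms

# Reader side: the same two steps as decode_raw + reshape in read_from_tfrecord.
snippet = reshape(decode_raw(snippet_bytes, 'f'), (X_HEIGHT, X_WIDTH, NUM_CHANNELS))
label = reshape(decode_raw(label_bytes, 'i'), (2,))
```

If the writer had stored float64 arrays instead (NumPy's default dtype), each snippet would carry twice as many bytes and the reshape would fail.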
And here is how I call it:
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(tf.global_variables_initializer())
    snippet, label = read_from_tfrecord(['./TFRecordFile/test_tmp.tfrecords'])
    print('1') # it prints 1
    a, b = sess.run([snippet, label]) # it hangs here!
    print('2') # it never prints 2
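The hang is consistent with how TF 1.x queue-based input works: `tf.train.start_queue_runners` starts threads only for the `QueueRunner`s already registered in the graph at the moment it is called. Here `read_from_tfrecord`, which creates the queues behind `string_input_producer` and `shuffle_batch`, runs afterwards, so nothing ever fills the queue that `sess.run` dequeues from. The toy model below imitates that ordering with Python's `queue` and `threading` modules; `Graph`, `build_input_pipeline` and the timeout-based `dequeue` are hypothetical stand-ins, not TensorFlow APIs.

```python
import queue
import threading


class Graph:
    """Toy stand-in for a TF1 graph's QUEUE_RUNNERS collection (not TensorFlow code)."""

    def __init__(self):
        self.queue_runners = []  # callables that fill a queue when run in a thread


def start_queue_runners(graph):
    """Like tf.train.start_queue_runners: start threads only for runners registered now."""
    threads = []
    for runner in list(graph.queue_runners):  # snapshot taken at call time
        t = threading.Thread(target=runner, daemon=True)
        t.start()
        threads.append(t)
    return threads


def build_input_pipeline(graph, records):
    """Like read_from_tfrecord: create a queue and register a runner that fills it."""
    q = queue.Queue(maxsize=10)

    def runner():
        for r in records:
            q.put(r)

    graph.queue_runners.append(runner)
    return q


def dequeue(q, timeout):
    """Like sess.run on a dequeue op; a real session blocks forever, so use a timeout."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None  # stands in for "hangs here"
```

In this model, calling `start_queue_runners` before `build_input_pipeline` leaves the queue empty forever, while calling it afterwards lets the first `dequeue` return a record.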
Any help is appreciated.
Answer
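OK, I solved the problem. It took me a lot of time to get this working, so I'm posting the answer here in case anyone else runs into a similar issue. In addition to Seven's suggestion, I had to add tf.train.start_queue_runners(sess). The code looks like this: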
# Build the input pipeline before any queue runners are started
snippet, label = read_from_tfrecord(['./TFRecordFile/test_tmp.tfrecords'])
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tf.train.start_queue_runners(sess) # <==== Add This Line
    sess.run(tf.global_variables_initializer())
    print('1') # it prints 1
    a, b = sess.run([snippet, label]) # now returns instead of hanging
    print('2') # now prints 2
    coord.request_stop()  # ask the queue-runner threads to stop
    coord.join(threads)   # wait for them before the session closes
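The `tf.train.Coordinator` in both versions of the code exists so the input threads can be shut down cleanly; in TF 1.x that is normally done with `coord.request_stop()` followed by `coord.join(threads)` once the main loop finishes. Below is a hedged stdlib sketch of the same stop-flag-and-join pattern; `Coordinator`, `producer` and `run` here are simplified hypothetical stand-ins built on `threading.Event`, not the TensorFlow classes.

```python
import queue
import threading


class Coordinator:
    """Minimal stand-in for tf.train.Coordinator: a shared stop flag plus join."""

    def __init__(self):
        self._stop = threading.Event()

    def should_stop(self):
        return self._stop.is_set()

    def request_stop(self):
        self._stop.set()

    def join(self, threads):
        for t in threads:
            t.join()


def producer(coord, q, source):
    """Loop like a QueueRunner thread: keep enqueueing until asked to stop."""
    i = 0
    while not coord.should_stop():
        try:
            # Cycle through the source forever, like string_input_producer with num_epochs=None.
            q.put(source[i % len(source)], timeout=0.05)
            i += 1
        except queue.Full:
            continue  # re-check the stop flag instead of blocking forever on a full queue


def run(num_items, source):
    """Main thread: consume a few items, then stop and join the producer thread."""
    coord = Coordinator()
    q = queue.Queue(maxsize=10)
    threads = [threading.Thread(target=producer, args=(coord, q, source))]
    for t in threads:
        t.start()
    try:
        got = [q.get() for _ in range(num_items)]
    finally:
        coord.request_stop()  # same role as coord.request_stop() in the TF code
        coord.join(threads)   # same role as coord.join(threads)
    return got, threads
```

The `put` timeout matters: a producer blocked on a full queue would otherwise never see the stop request, and `join` would wait forever.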