tf.train.shuffle_batch hangs forever (using tensorflow ver. 1.4)

Time: 2017-12-13 00:58:55

Tags: tensorflow queue

I have a small tfrecords file with only 640 records. The code below hangs, and I cannot figure out what is wrong with it:

import tensorflow as tf

def read_from_tfrecord(tfrecord_file):
    # Queue of input file names; the reader pulls one serialized example at a time
    tfrecord_file_queue = tf.train.string_input_producer(tfrecord_file, name='queue')
    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)
    tfrecord_features = tf.parse_single_example(
        tfrecord_serialized,
        features={'label': tf.FixedLenFeature([], tf.string),
                  'snippet': tf.FixedLenFeature([], tf.string)},
        name='features')

    # x_height, x_width and num_channels are not defined in this snippet (presumably set elsewhere)
    snippet = tf.decode_raw(tfrecord_features['snippet'], tf.float32)
    snippet = tf.reshape(snippet, [x_height, x_width, num_channels])
    label = tf.decode_raw(tfrecord_features['label'], tf.int32)
    label = tf.reshape(label, [2])

    snippets_shuffled, labels_shuffled = tf.train.shuffle_batch(
        [snippet, label],
        batch_size=2,
        capacity=10,
        num_threads=1,
        min_after_dequeue=4)
    return snippets_shuffled, labels_shuffled

with tf.Session()  as sess:    
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
   sess.run(tf.global_variables_initializer())

   snippet, label = read_from_tfrecord(['./TFRecordFile/test_tmp.tfrecords'])
   print('1') # it prints 1
   a, b = sess.run([snippet, label]) # it hangs here!
   print('2') # it never prints 2

Any help is appreciated.

1 Answer:

Answer 0 (score: 0)

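OK, I solved the problem. It took me a long time to get this working, so I am posting the answer here in case anyone else runs into a similar issue. In addition to following Seven's suggestion, I had to add tf.train.start_queue_runners(sess). Note that read_from_tfrecord is now also called before the session is created, so the input queues already exist in the graph when the queue runners are started. The code looks like this: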

snippet, label = read_from_tfrecord(['./TFRecordFile/test_tmp.tfrecords'])
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tf.train.start_queue_runners(sess) # <==== Add This Line
    sess.run(tf.global_variables_initializer())

    print('1') # it prints 1
    a, b = sess.run([snippet, label]) # this call hung in the original code
    print('2') # now reached

    # Stop the queue runner threads before leaving the session
    coord.request_stop()
    coord.join(threads)
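
For readers wondering why the order matters: tf.train.start_queue_runners only starts threads for the queue runners that are already registered in the graph when it is called. In the question it runs before read_from_tfrecord builds the queues, so nothing ever fills them and the dequeue inside sess.run blocks. The snippet below is only a plain-Python analogy built on the standard library queue and threading modules, not TensorFlow code. The names make_pipeline and try_dequeue are made up for this illustration.

```python
import queue
import threading


def make_pipeline():
    """Build a queue and a producer thread that has NOT been started yet.

    Loosely analogous to building the input graph: the queue exists,
    but nothing fills it until the producer is started.
    """
    q = queue.Queue(maxsize=10)

    def produce():
        for i in range(4):
            q.put(i)

    producer = threading.Thread(target=produce, daemon=True)
    return q, producer


def try_dequeue(q, timeout=0.2):
    """Return the next item, or None if nothing arrives within `timeout` seconds."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


q, producer = make_pipeline()

# No producer is running yet, so the queue stays empty.
# A get() without a timeout would block forever here, like the hanging sess.run.
before = try_dequeue(q)

# Analogous to starting the queue runners after the pipeline has been built.
producer.start()
after = try_dequeue(q, timeout=2.0)

print(before, after)  # None 0
```

The timeout is there only so the example terminates. The dequeue in the question has no timeout, which is why the session call never returns.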