在TensorFlow中,为什么tf.train.shuffle_batch永远会挂起并且不会返回批次?

时间:2017-08-10 05:14:39

标签: tensorflow

我是tensorflow的新手,目前正在尝试使用csv格式的数据生成批次。

我遵循了Tensor Flow中的阅读数据教程(https://www.tensorflow.org/programmers_guide/reading_data),但我必须误解一些事情,因为我的代码永远存在。 我在教程中使用了read_my_file_format函数并且它有效。现在我想按照以下方式训练我的网络实际使用批处理:

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(
         filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
print('read_my_file is done')
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
  [example, label], batch_size=batch_size, capacity=capacity,
  min_after_dequeue=min_after_dequeue)
print('all done but the return')
return example_batch, label_batch

with tf.Session() as sess:
batch_size=5
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

batch_data,batch_label=sess.run(input_pipeline(file_name,batch_size=batch_size))
print('return is done')
print(batch_data,batch_label)
coord.request_stop()
coord.join(threads)

为了调试,在上面的代码中,我只是尝试打印生成的批处理,而不是将其提供给网络。凭借我的版画,我能够看到它挂在哪里:就在之前     return example_batch,label_batch。

我的神经网络准备就绪,我的数据已被处理,所以这是阻止我在我的项目(Supernovae Classification)中前进的唯一因素。你有什么建议或意见吗?我已经坚持了一段时间。

此外,如果需要,我的文件名中只有一个输入文件。

谢谢

1 个答案:

答案 0 :(得分:1)

您需要初始化变量。

with tf.Session() as sess:
    ...
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(tf.global_variables_initializer())
    ...