我是tensorflow的新手,目前正在尝试使用csv格式的数据生成批次。
我遵循了Tensor Flow中的阅读数据教程(https://www.tensorflow.org/programmers_guide/reading_data),但我必须误解一些事情,因为我的代码永远存在。 我在教程中使用了read_my_file_format函数并且它有效。现在我想按照以下方式训练我的网络实际使用批处理:
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
print('read_my_file is done')
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
print('all done but the return')
return example_batch, label_batch
with tf.Session() as sess:
batch_size=5
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
batch_data,batch_label=sess.run(input_pipeline(file_name,batch_size=batch_size))
print('return is done')
print(batch_data,batch_label)
coord.request_stop()
coord.join(threads)
为了调试,在上面的代码中,我只是尝试打印生成的批处理,而不是将其提供给网络。凭借我的版画,我能够看到它挂在哪里:就在之前 return example_batch,label_batch。
我的神经网络准备就绪,我的数据已被处理,所以这是阻止我在我的项目(Supernovae Classification)中前进的唯一因素。你有什么建议或意见吗?我已经坚持了一段时间。
此外,如果需要,我的文件名中只有一个输入文件。
谢谢
答案 0 :(得分:1)
您需要初始化变量。
with tf.Session() as sess:
...
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(tf.global_variables_initializer())
...