设置num_epochs时,Tensorflow string_input_producer行为异常

时间:2016-07-01 00:09:13

标签: tensorflow

在设置num_epochs时,string_input_producer似乎没有将字符串排队。

使用以下代码,程序打印[0],这是不对的。

`import tensorflow as tf
sess = tf.InteractiveSession()
filenames = ["1", "2", "3"]
filename_queue = tf.train.string_input_producer(filenames, num_epochs=10)
test_value = tf.convert_to_tensor(filename_queue.size())

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

print(sess.run([test_value]))`

但如果我拿出num_epochs,

`import tensorflow as tf
sess = tf.InteractiveSession()
filenames = ["1", "2", "3"]
filename_queue = tf.train.string_input_producer(filenames)
test_value = tf.convert_to_tensor(filename_queue.size())

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

print(sess.run([test_value]))`

按预期打印[3]。

还有其他人遇到同样的问题吗?

2 个答案:

答案 0 :(得分:0)

解决方案是添加以下行:

sess.run(tf.initialize_all_variables())

...在启动队列运行程序之前。 tf.train.string_input_producer()函数在内部创建TensorFlow variable以跟踪当前的纪元索引,并且必须在首次使用之前进行初始化(将在您启动队列运行程序时进行初始化)。

答案 1 :(得分:0)

如今,请使用

  

sess.run(tf.local_variables_initializer())

在启动队列运行器之前。 因为如果num_epochs不是None,tf.train.string_input_producer()会创建本地计数器。