如何在sess.run
中未指定num_epochs
的情况下检查slice_input_producer
是否在一个时期内完成了所有数据?
我认为使用while not coord.should_stop()
(参见下面的代码)应该在它通过所有数据后停止,但它永远不会停止。它只是一遍又一遍地重复。
唯一的解决方案是在num_epochs=1
中设置slice_input_producer
而不是在tf.errors.OutOfRangeError
中捕获它,因此我会知道它是通过了所有数据吗?
batch_size = 3
input_queue = tf.train.slice_input_producer([image_filenames, label_filenames],
num_epochs=None,
shuffle=True)
image_batch, label_batch = generate_batch(input_queue, batch_size)
with tf.Session(...) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
images,labels = sess.run([train_image_batch, train_label_batch])
pltShow(images, labels)
except tf.errors.OutOfRangeError:
print('OutOfRangeError')
finally:
print("requesting threads to stop.")
coord.request_stop()
UPDATE1
我尝试在num_epochs=1
中设置slice_input_producer
然后执行此操作,但它没有帮助。它在第一个纪元后停止
epoch = 0
while epoch < 10:
try:
while not coord.should_stop():
images,labels = sess.run([train_image_batch, train_label_batch])
pltShow(images, labels)
except tf.errors.OutOfRangeError:
print('OutOfRangeError')
finally:
print("requesting threads to stop.")
epoch += 1