如何检查我在一个时代内完成了所有数据?

时间:2018-05-25 19:20:12

标签: python tensorflow

如何在sess.run中未指定num_epochs的情况下检查slice_input_producer是否在一个时期内完成了所有数据?

我认为使用while not coord.should_stop()(参见下面的代码)应该在它通过所有数据后停止,但它永远不会停止。它只是一遍又一遍地重复。

唯一的解决方案是在num_epochs=1中设置slice_input_producer而不是在tf.errors.OutOfRangeError中捕获它,因此我会知道它是通过了所有数据吗?

batch_size = 3
input_queue = tf.train.slice_input_producer([image_filenames, label_filenames], 
                                                num_epochs=None,
                                                shuffle=True)
image_batch, label_batch = generate_batch(input_queue, batch_size) 

with tf.Session(...) as sess:
    sess.run(tf.global_variables_initializer())            
    sess.run(tf.local_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        while not coord.should_stop():
            images,labels = sess.run([train_image_batch, train_label_batch])
            pltShow(images, labels)

    except tf.errors.OutOfRangeError:
        print('OutOfRangeError')
    finally:
        print("requesting threads to stop.")
        coord.request_stop()

UPDATE1

我尝试在num_epochs=1中设置slice_input_producer然后执行此操作,但它没有帮助。它在第一个纪元后停止

epoch = 0
while epoch < 10:
    try:             
        while not coord.should_stop():
            images,labels = sess.run([train_image_batch, train_label_batch])
            pltShow(images, labels)

    except tf.errors.OutOfRangeError:
        print('OutOfRangeError')
    finally:
        print("requesting threads to stop.")

    epoch += 1

0 个答案:

没有答案