Tensorflow:恢复会话后无法重新打开队列

时间:2016-06-04 15:44:12

标签: machine-learning tensorflow

我有一个训练有素的模型,我试图在单独的数据集上进行评估,而且我的输入管道出现问题。恢复会话并尝试加载第一批验证数据后,将引发以下错误:

tensorflow.python.framework.errors.OutOfRangeError: FIFOQueue '_2_input/batch/fifo_queue' is closed and has insufficient elements (requested 1024, current size 0)

我的代码是在cifar10_eval.py示例(see here)之后建模的。

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, record = reader.read(filename_queue)

    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
    })

    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image_raw'], tf.uint8)

    image.set_shape([21*21*1])
    image = tf.cast(tf.reshape(image, (21, 21, 1)), tf.float32)

    return image, label

def inputs(train, batch_size, num_epochs):
    if train:
        filename = os.path.join(DATA_DIR, TRAIN_FILE)
    else:
        filename = os.path.join(DATA_DIR, TEST_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs, shuffle=train)

        example, label = read_and_decode(filename_queue)

        min_after_dequeue = 1000
        capacity = min_after_dequeue + 3 * batch_size

        if train:
            example_batch, label_batch = tf.train.shuffle_batch(
                [example, label], batch_size=batch_size, capacity=capacity,
                min_after_dequeue=min_after_dequeue)
        else:
            example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size,
                    capacity = capacity)

        return example_batch, label_batch

def evaluate_model():
    with tf.Graph().as_default():
        images, labels = inputs(train=False, batch_size=1024,
                num_epochs=NUM_EPOCHS)

        keep_prob = tf.Variable(1.0, name='keep_prob', trainable=False)

        logits = inference(images, keep_prob)
        training_error = batch_training_error(logits, labels)
        summary_op = tf.merge_all_summaries()

        sess = tf.Session()

        log_dir = os.path.join(SUMMARY_DIR, "eval2")
        writer = tf.train.SummaryWriter(log_dir, sess.graph)

        saver = tf.train.Saver()
        saver.restore(sess, 'checkpoint/model-1280')

        keep_prob.assign(1.0)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        #threads = []
        #for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        #    new_threads = qr.create_threads(sess, coord=coord, daemon=True, start=True)
        #    threads.extend(new_threads)

        try:
            step = 0
            while not coord.should_stop():
                err = sess.run(training_error)
                print("Step %d, batch training error: %.3f" % (step, err))

                if step % 10 == 0:
                    summary = sess.run(summary_op)
                    writer.add_summary(summary, global_step=step)
                    print('Summary written.')

                step += 1
        #except tf.errors.OutOfRangeError:
        #    print('Done training for %d epochs, %d steps.' % (NUM_EPOCHS, step))
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()

evaluate_model()

我是Tensorflow的新手,而且我无法理解我出错的地方。任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:2)

尝试替换

saver = tf.train.Saver()

saver = tf.train.Saver( tf.trainable_variables() )

那为我做了。我在评论中坚持我的解释。您需要避免恢复队列(input_producer)状态。我还必须附加像“global_step”这样的非训练,我想跟踪该列表。