tensorflow教程代码中的评估背后的基本原理cifar10_eval.py

时间:2017-05-11 03:34:55

标签: tensorflow

在TF的官方教程代码'cifar10'中,有一个评估片段:

def evaluate():    
with tf.Graph().as_default() as g:
            # Get images and labels for CIFAR-10.
            eval_data = FLAGS.eval_data == 'test'
            images, labels = cifar10.inputs(eval_data=eval_data)

            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = cifar10.inference(images)

            # Calculate predictions.
            top_k_op = tf.nn.in_top_k(logits, labels, 1)

            # Restore the moving average version of the learned variables for eval.
            variable_averages = tf.train.ExponentialMovingAverage(
                cifar10.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.summary.merge_all()

            summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

            while True:
                eval_once(saver, summary_writer, top_k_op, summary_op)
                if FLAGS.run_once:
                    break
                time.sleep(FLAGS.eval_interval_secs)

在运行时,它会评估一批测试样本并在控制台中每隔一个eval_interval_secs打印出“精度”,我的问题是:

  1. 每次执行eval_once()时,一批样本(128)从数据队列中出队,但为什么在批次足够后我没有看到评估停止,10000/128 + 1 = 79批次?我认为应该在79批次之后停止。

  2. 来自前79个样本的批次是否互相排斥?我想是这样,但想仔细检查一下。

  3. 如果每个批次确实从数据队列中出列,那么79次采样后的样本是什么?再次从整个重复数据队列中随机抽样?

  4. 因为in_top_k()正在接受一些非标准化的logit值并输出一串布尔值,这掩盖了softmax()+阈值的内部转换。是否存在用于此类显式计算的TF操作?理想情况下,能够调整阈值并查看不同的分类结果非常有用。

  5. 请帮忙。 谢谢!

1 个答案:

答案 0 :(得分:1)

  1. 您可以在"输入"中看到以下行。 def of cifar10_input.py

    filename_queue = tf.train.string_input_producer(filenames) 
    

    有关tf.train.string_input_producer的更多信息:

    string_input_producer(
        string_tensor,
        num_epochs=None,
        shuffle=True,
        seed=None,
        capacity=32,
        shared_name=None,
        name=None,
        cancel_op=None
     )
    

    num_epochs:在生成OutOfRange错误之前,从string_tensor num_epochs次生成每个字符串。如果未指定,string_input_producer可以循环遍历string_tensor中的字符串无限次。

    在我们的例子中,没有指定num_epochs。这就是为什么它不会在几批后停止的原因。它可以无限次运行。

  2. 默认情况下,shfle选项在tf.train.string_input_producer中设置为True。因此,它首先将数据混洗,然后一次又一次地复制10K文件名。

    因此,它是相互排斥的。您可以打印文件名以查看此内容。

  3. 如1中所述,它们是重复的样品。 (不是任何随机数据)

  4. 您可以避免使用tf.nn.in_top_k。使用tf.nn.softmax和tf.greater_equal获取softmax值高于特定阈值的布尔张量。

  5. 我希望这会有所帮助。如果有任何误解,请评论。