cntk

时间:2017-08-19 16:01:10

标签: python deep-learning cntk

我们以TrainResNet_CIFAR10.py为例来学习cntk。我们创建了两个方法,eval_metric和calc_error,如下所示:

def eval_metric(trainer, reader_test, test_epoch_size, label_var, input_map) :
    # Evaluation parameters
    minibatch_size = 16

    # process minibatches and evaluate the model
    metric_numer    = 0
    metric_denom    = 0
    sample_count    = 0

    while sample_count < test_epoch_size:
        current_minibatch = min(minibatch_size, test_epoch_size - sample_count)
        # Fetch next test min batch.
        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)
        # minibatch data to be trained with
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples

    return metric_numer / metric_denom

def calc_error(trainer, fileList, mean_value, test_size) :
    if (len(fileList) != test_size) :
        return 0

    n   = 0
    m = 0
    while n < test_size:
        c = evalute(trainer, fileList[n].filename, mean_value);
        if (c != fileList[n].classID) :
            m += 1 
        n += 1

    return m / test_size

def evalute(trainer, img_name, mean_value) :
    rgb_image = np.asarray(Image.open(img_name), dtype=np.float32) - mean_value
    bgr_image = rgb_image[..., [2, 1, 0]]
    pic = np.ascontiguousarray(np.rollaxis(bgr_image, 2))
    probs = trainer.eval({trainer.arguments[0]:[pic]})
    predictions = np.squeeze(probs)
    top_class = np.argmax(predictions)
    return top_class

我们认为test_minibatch(data)返回不正确结果的百分比,这两种方法应该给出类似的结果。我的问题是:

  1. trainer.test_minibatch(数据)返回什么?
  2. 对于CIFAR-10测试图像,两种方法之间的差异在10%以内,但对于我们自己的样本图像,它们具有64x64x3和4个类别,差异超过100%。什么可能导致巨大差异?
  3. 如果我们将培训师直接传递给calc_error,它会在评估期间出错。我们必须在调用calc_error之前先保存和load_model,为什么?

1 个答案:

答案 0 :(得分:0)

trainer.test_minibatch返回损失的平均值(或通常是第一个参数)。

调用test_minibatch后,还可以使用这些方法: trainer.previous_minibatch_loss_averagetrainer.previous_minibatch_sample_counttrainer.previous_minibatch_evaluation_average

差异可能来自预处理。 mean_value与您训练网络时的相同吗?它是RGB顺序还是BGR顺序?

您是否考虑过将评估集缩减为单个图像并确认您使用阅读器获得完全相同的输出并手动加载图像?