Caffe的报告准确性是否可靠?

时间:2016-07-02 18:18:32

标签: caffe confusion-matrix

我最近试图获得一个训练模型的混淆矩阵,看看它有多精确。我下载了这个脚本并喂了我的模型。 令我惊讶的是,脚本计算的准确度与一个Caffe报告非常不同。

我使用this script来计算混淆矩阵,然而,这也报告了准确性,问题是此脚本报告的准确度与{{1}报告的准确度不同}!
例如,Caffe报告Caffe的准确度为92.34%,而当模型被输入脚本以计算混淆矩阵及其准确度时,它会导致例如86.5% 。!

这些准确性中的哪一个是正确的,并且可以在论文中报告或与其他论文的结果进行比较,例如那些here

我还看到了一些奇怪的东西,我训练了两个相同的模型,只有一个区别,一个使用CIFAR10,另一个使用xavier进行初始化。
第一个报告的准确度为94.25,另一个报告的报告为94.26。当这些模型被馈送到我上面链接的脚本时,用于混淆矩阵计算。他们的准确率分别为89.2%和87.4%! 这是正常的吗?这是什么原因? MSRA?

我真的不知道caffe报道的准确度是否真实值得和可靠。如果有人能够解释这个问题,我将不胜感激。

P.N: 脚本的准确性计算为(complete script):

msra

哪个imho可以正确。正确预测的数量除以数据集中的实例总数。

1 个答案:

答案 0 :(得分:1)

我找到了原因。 Caffe生成的准确性与所讨论的脚本产生的精确度之间不匹配的原因仅仅是因为平均减法,这是在caffe中完成的,而不是在脚本中完成的。 This是脚本的修改版本,它考虑到了这一点,希望一切都很好。

# Author: Axel Angel, copyright 2015, license GPLv3.
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction
# Seyyed Hossein Hasan Pour
# Coderx7@Gmail.com
# 7/3/2016 

import sys
import caffe
import numpy as np
import lmdb
import argparse
from collections import defaultdict

def flat_shape(x):
    "Returns x without singleton dimension, eg: (1,28,28) -> (28,28)"
    return x.reshape(filter(lambda s: s > 1, x.shape))

def lmdb_reader(fpath):
    import lmdb
    lmdb_env = lmdb.open(fpath)
    lmdb_txn = lmdb_env.begin()
    lmdb_cursor = lmdb_txn.cursor()

    for key, value in lmdb_cursor:
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), label)

def leveldb_reader(fpath):
    import leveldb
    db = leveldb.LevelDB(fpath)

    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), label)

def npz_reader(fpath):
    npz = np.load(fpath)

    xs = npz['arr_0']
    ls = npz['arr_1']

    for i, (x, l) in enumerate(np.array([ xs, ls ]).T):
        yield (i, x, l)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--mean', type=str, required=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

# Extract mean from the mean image file
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    f = open(args.mean, 'rb')
    mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)
    f.close()

    count = 0
    correct = 0
    matrix = defaultdict(int) # (real,pred) -> int
    labels_set = set()

   # CNN reconstruction and loading the trained weights 
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    caffe.set_mode_cpu()
    print "args", vars(args)
    if args.lmdb != None:
        reader = lmdb_reader(args.lmdb)
    if args.leveldb != None:
        reader = leveldb_reader(args.leveldb)
    if args.npz != None:
        reader = npz_reader(args.npz)

    for i, image, label in reader:
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([ image_caffe ])- mean_image)
        plabel = int(out['prob'][0].argmax(axis=0))

        count += 1
        iscorrect = label == plabel
        correct += (1 if iscorrect else 0)
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])

        if not iscorrect:
            print("\rError: i=%s, expected %i but predicted %i" \
                    % (i, label, plabel))

        sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count))
        sys.stdout.flush()

    print(", %i/%i corrects" % (correct, count))

    print ""
    print "Confusion matrix:"
    print "(r , p) | count"
    for l in labels_set:
        for pl in labels_set:
            print "(%i , %i) | %i" % (l, pl, matrix[(l,pl)])