如何显示keras中每个时代的每个班级的准确性

时间:2018-05-31 03:38:55

标签: keras mnist

我尝试训练mnist数据集 在训练时,我希望显示每个时期的每个课程的准确性,而不是整个数据集的准确性。 我们应该做什么?改变回调()? 谢谢你!

1 个答案:

答案 0 :(得分:0)

最后自己弄清楚xD 使用回调可以解决这个问题 以mnist数据集为例,我想在这里显示数字5类精度, 执行以下操作:

class TestCallback(Callback):
def __init__(self, test_data):
    self.test_data = test_data

def on_epoch_end(self, epoch, logs={}):
    x, y = self.test_data
    pred = self.model.predict(x)
    true = y
    prediction = np.argmax(pred,axis=1)
    label = np.argmax(true,axis=1)
    acc = 0
    tar = label[label==5]
    size_of_5 = len(tar)
    print("there are %d of 5"%(size_of_5))
    for i in range(len(label)):
        if label[i]==5:
            if prediction[i]==5:
                 acc += 1/size_of_5
    print('\n digit 5 accuracy:{}\n'.format(acc))