我尝试训练mnist数据集 在训练时,我希望显示每个时期的每个课程的准确性,而不是整个数据集的准确性。 我们应该做什么?改变回调()? 谢谢你!
答案 0 :(得分:0)
最后自己弄清楚xD 使用回调可以解决这个问题 以mnist数据集为例,我想在这里显示数字5类精度, 执行以下操作:
class TestCallback(Callback):
def __init__(self, test_data):
self.test_data = test_data
def on_epoch_end(self, epoch, logs={}):
x, y = self.test_data
pred = self.model.predict(x)
true = y
prediction = np.argmax(pred,axis=1)
label = np.argmax(true,axis=1)
acc = 0
tar = label[label==5]
size_of_5 = len(tar)
print("there are %d of 5"%(size_of_5))
for i in range(len(label)):
if label[i]==5:
if prediction[i]==5:
acc += 1/size_of_5
print('\n digit 5 accuracy:{}\n'.format(acc))