Keras中子类别的准确性度量标准

时间:2019-05-17 05:35:00

标签: python machine-learning keras metrics tensor

我遇到了3类分类问题。让我们将它们定义为类0,1和2。就我而言,类0不重要-也就是说,归为类0的任何内容都不相关。但是,重要的是仅适用于类别1和2的准确性,准确性,召回率和错误率。我想定义一个准确性度量标准,该度量标准仅查看与1和2相关的数据的一部分,并提供给我一个度量因为模型是训练。我不是要求提供准确性或f1或精度/召回率的代码-这些是我发现的并且可以实现自己的代码。我要的是代码,可以帮助选择类别的子部分以执行这些指标。 在视觉上,带有混淆矩阵: 鉴于:

>  0   1   2
>0 10  3   4
>1 2   5   1
>2 8   5   9

我只想对以下子集进行训练中的准确性测量:

>  1   2
>1 5   1
>2 5   9

可能的想法: 将已分类的argmaxed y_pred和argmaxed y_true串联起来,删除所有出现0的实例,将它们重新分解回one_hot数组,并对剩余的数据进行简单的二进制精度?

编辑: 我试图通过此代码排除0类,但这没有任何意义。 0类被有效地包装到1类中(也就是说,0和1的真实正数最终都被标记为1)。还在寻求帮助-有人可以帮忙吗?

#this solution does not work :(
def my_acc(y_true, y_pred):
#excluding the 0-category
y_true_cust = y_true[:,np.r_[1:3]]
y_pred_cust = y_pred[:,np.r_[1:3]]
#binary accuracy source code, slightly edited
y_pred_cat = Ker.round(y_pred_cust)
eql_cust = Ker.equal(y_true_cust, y_pred_cust)
return Ker.mean(eql_cust, axis = -1)

@ Ashwin Geet D'Sa

correct_guesses_3cat = 10 + 5 + 9
print(correct_guesses_3cat)
24

total_guesses_3cat = 10+3+4+2+5+1+8+5+9
print(total_guesses_3cat)
47

accuracy_3cat = 24/47
print(accuracy_3cat)
51.1 %

correct_guesses_2cat =5 + 9
print(correct_guesses_2cat)
14

total_guesses_2cat = 5+1+5+9
print(total_guesses_2cat)
20

accuracy_2cat = 14/20
print(accuracy_2cat)
70.0 %

0 个答案:

没有答案