我有一个两类分类问题,我想知道这两个类中每个类的正确分类的数量:
correct = tf.nn.in_top_k(logits, labels, 1)
例如,我分享了90到10个class1到class2的例子,我希望得到类似correct_class1
和correct_class2
答案 0 :(得分:0)
您实际上是在计算直方图。
这是一个粗略的想法。假设我们有3个班级,
标签= [1,3,2,2,3,1,1,1]
correct = tf.nn.in_top_k(logits,labels,1)= [1,1,0,0,1,0,1,1]
第1步: 计算每个类(计算标签的直方图):[4(第一类),第2类(第二类),第2类(第三类)]
第2步: 正确预测,标签的元素乘积和正确:[1,3,0,0,3,0,1,1] = tf.mul(标签,正确)
第3步: 每个类的正确计数(计算步骤2结果的直方图):[3(零级:错误预测),3级(1级),0级(2级),2级(3级)]
步骤4(从步骤1和3的结果): 结果:[3 / 4,0 / 2,2 / 2] = [0.75,0,1]
numpy.histogram或tf.histogram_summary可能很方便实现这一点。