Tensorflow:按类确定正确的分类份额

时间:2016-08-10 11:20:37

标签: tensorflow

我有一个两类分类问题,我想知道这两个类中每个类的正确分类的数量:

correct = tf.nn.in_top_k(logits, labels, 1)

例如,我分享了90到10个class1到class2的例子,我希望得到类似correct_class1correct_class2

的内容。

1 个答案:

答案 0 :(得分:0)

您实际上是在计算直方图。

这是一个粗略的想法。假设我们有3个班级,

标签= [1,3,2,2,3,1,1,1]

correct = tf.nn.in_top_k(logits,labels,1)= [1,1,0,0,1,0,1,1]

第1步: 计算每个类(计算标签的直方图):[4(第一类),第2类(第二类),第2类(第三类)]

第2步: 正确预测,标签的元素乘积和正确:[1,3,0,0,3,0,1,1] = tf.mul(标签,正确)

第3步: 每个类的正确计数(计算步骤2结果的直方图):[3(零级:错误预测),3级(1级),0级(2级),2级(3级)]

步骤4(从步骤1和3的结果): 结果:[3 / 4,0 / 2,2 / 2] = [0.75,0,1]

numpy.histogram或tf.histogram_summary可能很方便实现这一点。