我正在学习张量流,在各种示例中,我已经看到要使用log tf.argmax(logits, 1)
从logit获得预测。根据我的理解,logits
是概率值,tf.argmax()
将给出指定轴上最大值的 index 。但是,如何使用索引代替概率值。我们不应该将最大值用作预测吗?
但是我已经看到上面的代码可以正常工作。我确定我在这里缺少一些基础知识。有人可以举例说明吗?
答案 0 :(得分:5)
通常logits
是分类网络的输出张量,其内容为未归一化(不在0到1之间缩放)的概率。
tf.argmax
为您提供沿指定轴的最大值索引。
您可以将logits
转换为伪概率(这只是张量,其值的总和为1),并将其作为输入送入argmax:
top = tf.argmax(tf.nn.softmax(logits), 1)
但是最后,结果与直接馈入未归一化的概率相同:
top = tf.argmax(logits, 1)
但是,您必须使用argmax才能了解网络为该输入预测的类,这是唯一的方法,您不能仅使用概率(归一化或非归一化)。
只需考虑一个logits张量:
logits = [ [ 10, 500, -1, 0.5, 12 ] ]
张量形状为[1,5]。只需查看张量值,您就可以轻松地了解到,置信度最高的类是与位置1相关联的类,其值为500。
如何提取最高值的位置?您必须使用argmax:
top = tf.argmax(logits, 1)
执行后将返回值1
摘要:
logits的值为Scores
,索引为Classes
。通过使用argmax()
,可以获得predicted class