Tensorflow:tf.argmax()是预测值还是最大值?

时间:2018-06-30 10:56:40

标签: tensorflow

我正在学习张量流,在各种示例中,我已经看到要使用log tf.argmax(logits, 1)从logit获得预测。根据我的理解,logits是概率值,tf.argmax()将给出指定轴上最大值的 index 。但是,如何使用索引代替概率值。我们不应该将最大值用作预测吗?

但是我已经看到上面的代码可以正常工作。我确定我在这里缺少一些基础知识。有人可以举例说明吗?

1 个答案:

答案 0 :(得分:5)

通常logits是分类网络的输出张量,其内容为未归一化(不在0到1之间缩放)的概率。

tf.argmax为您提供沿指定轴的最大值索引。

您可以将logits转换为伪概率(这只是张量,其值的总和为1),并将其作为输入送入argmax:

top = tf.argmax(tf.nn.softmax(logits), 1)

但是最后,结果与直接馈入未归一化的概率相同:

top = tf.argmax(logits, 1)

但是,您必须使用argmax才能了解网络为该输入预测的类,这是唯一的方法,您不能仅使用概率(归一化或非归一化)。

只需考虑一个logits张量:

logits = [ [ 10, 500, -1, 0.5, 12 ] ]

张量形状为[1,5]。只需查看张量值,您就可以轻松地了解到,置信度最高的类是与位置1相关联的类,其值为500。

如何提取最高值的位置?您必须使用argmax:

top = tf.argmax(logits, 1)

执行后将返回值1

摘要:      logits的值为Scores,索引为Classes。通过使用argmax(),可以获得predicted class