KERAS“ sparse_categorical_crossentropy”问题

时间:2018-11-23 12:40:35

标签: python tensorflow machine-learning keras loss

a的浮点数为1.0或0.0。当我尝试用模型和sparse_categorical_crossentropy损失进行预测时,我得到如下信息: [[0.4846592 0.5153408]]

我怎么知道它预测什么类别?

1 个答案:

答案 0 :(得分:5)

您看到的这些数字是给定输入样本的每个类别的概率。例如,[[0.4846592 0.5153408]]表示给定的样本属于类别0,概率约为0.48,而属于样本1,概率约为0.51。因此,您希望以最高的概率上课,因此可以使用np.argmax来找出哪个索引(即0或1)是最大的索引:

import numpy as np

pred_class = np.argmax(probs, axis=-1) 

此外,这与模型的损失函数无关。这些概率由模型的最后一层给出,很有可能它使用softmax作为激​​活函数来将输出标准化为概率分布。