什么是logits,softmax和softmax_cross_entropy_with_logits?

时间:2015-12-12 14:03:27

标签: python machine-learning tensorflow

我正在浏览tensorflow API文档here。在tensorflow文档中,他们使用了一个名为logits的关键字。它是什么?在API文档中的许多方法中,它都像

一样编写
tf.nn.softmax(logits, name=None)

如果撰写的是logits只有Tensors,那么为什么要保留logits之类的其他名称?

另一件事是有两种我无法区分的方法。他们是

tf.nn.softmax(logits, name=None)
tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

它们之间有什么区别?这些文档对我来说并不清楚。我知道tf.nn.softmax的作用。但不是另一个。一个例子将非常有用。

8 个答案:

答案 0 :(得分:381)

Logits只是意味着该函数在早期图层的未缩放输出上运行,并且理解单位的相对比例是线性的。这意味着,特别是输入的总和可能不等于1,值概率(您可能输入为5)。

tf.nn.softmax只生成将softmax function应用于输入张量的结果。 softmax" squhes"输入使sum(input) = 1:它是一种规范化的方式。 softmax的输出形状与输入相同:它只是将值标准化。 softmax 的输出可以解释为概率。

a = tf.constant(np.array([[.1, .3, .5, .9]]))
print s.run(tf.nn.softmax(a))
[[ 0.16838508  0.205666    0.25120102  0.37474789]]

相比之下,tf.nn.softmax_cross_entropy_with_logits在应用softmax函数后计算结果的交叉熵(但它以更加数学上仔细的方式一起完成)。它类似于以下结果:

sm = tf.nn.softmax(x)
ce = cross_entropy(sm)

交叉熵是一个汇总度量:它对元素进行求和。形状tf.nn.softmax_cross_entropy_with_logits张量上[2,5]的输出形状为[2,1](第一维被视为批次)。

如果你想进行优化以最大限度地减少交叉熵 AND ,那么你应该在最后一层之后进行softmaxing,你应该使用tf.nn.softmax_cross_entropy_with_logits而不是自己动手,因为它涵盖了在数学上正确的方式数值不稳定的角落情况。否则,你最终会通过在这里和那里添加小ε来破解它。

编辑2016-02-07: 如果您有单类标签,其中一个对象只能属于一个类,您现在可以考虑使用tf.nn.sparse_softmax_cross_entropy_with_logits,这样就不必将标签转换为密集的单热阵列。在0.6.0版本之后添加了此功能。

答案 1 :(得分:254)

简短版:

假设您有两个张量,其中y_hat包含每个类的计算得分(例如,从y = W * x + b),y_true包含一个热编码的真实标签。

y_hat  = ... # Predicted label, e.g. y = tf.matmul(X, W) + b
y_true = ... # True label, one-hot encoded

如果您将y_hat中的分数解释为非标准化日志概率,则它们 logits

此外,以这种方式计算的总交叉熵损失:

y_hat_softmax = tf.nn.softmax(y_hat)
total_loss = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), [1]))

基本上等于用函数softmax_cross_entropy_with_logits()计算的总交叉熵损失:

total_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))

长版:

在神经网络的输出层,您可能会计算一个数组,其中包含每个训练实例的类别分数,例如计算y_hat = W*x + b。作为一个例子,我在下面创建了一个y_hat作为2 x 3数组,其中行对应于训练实例,列对应于类。所以这里有2个训练实例和3个课程。

import tensorflow as tf
import numpy as np

sess = tf.Session()

# Create example y_hat.
y_hat = tf.convert_to_tensor(np.array([[0.5, 1.5, 0.1],[2.2, 1.3, 1.7]]))
sess.run(y_hat)
# array([[ 0.5,  1.5,  0.1],
#        [ 2.2,  1.3,  1.7]])

请注意,这些值未规范化(即行不要加1)。为了对它们进行归一化,我们可以应用softmax函数,该函数将输入解释为非标准化的对数概率(又名 logits )并输出归一化的线性概率。

y_hat_softmax = tf.nn.softmax(y_hat)
sess.run(y_hat_softmax)
# array([[ 0.227863  ,  0.61939586,  0.15274114],
#        [ 0.49674623,  0.20196195,  0.30129182]])

完全理解softmax输出的含义非常重要。下面我展示了一个更清楚地代表上面输出的表格。可以看出,例如,训练实例1的概率为" Class 2"是0.619。每个训练实例的类概率都是标准化的,因此每行的总和为1.0。

                      Pr(Class 1)  Pr(Class 2)  Pr(Class 3)
                    ,--------------------------------------
Training instance 1 | 0.227863   | 0.61939586 | 0.15274114
Training instance 2 | 0.49674623 | 0.20196195 | 0.30129182

所以现在我们有每个训练实例的类概率,我们可以采用每行的argmax()来生成最终的分类。从上面,我们可以生成训练实例1属于" Class 2"和训练实例2属于" Class 1"。

这些分类是否正确?我们需要根据训练集中的真实标签进行衡量。您将需要一个单热编码的y_true数组,其中行也是训练实例,列是类。下面我创建了一个示例y_true one-hot数组,其中训练实例1的真实标签是" Class 2"培训实例2的真正标签是" Class 3"。

y_true = tf.convert_to_tensor(np.array([[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]]))
sess.run(y_true)
# array([[ 0.,  1.,  0.],
#        [ 0.,  0.,  1.]])

y_hat_softmax中的概率分布是否接近y_true中的概率分布?我们可以使用cross-entropy loss来衡量错误。

Formula for cross-entropy loss

我们可以逐行计算交叉熵损失并查看结果。下面我们可以看到训练实例1的损失为0.479,而训练实例2的损失则高达1.200。这个结果是有道理的,因为在上面的例子中,y_hat_softmax表明训练实例1的最高概率是" Class 2",它匹配y_true中的训练实例1;然而,对训练实例2的预测显示出" Class 1"的概率最高,这与真正的等级" Class 3"不匹配。

loss_per_instance_1 = -tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1])
sess.run(loss_per_instance_1)
# array([ 0.4790107 ,  1.19967598])

我们真正想要的是所有训练实例的总损失。所以我们可以计算:

total_loss_1 = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1]))
sess.run(total_loss_1)
# 0.83934333897877944

使用softmax_cross_entropy_with_logits()

我们可以使用tf.nn.softmax_cross_entropy_with_logits()函数计算总交叉熵损失,如下所示。

loss_per_instance_2 = tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true)
sess.run(loss_per_instance_2)
# array([ 0.4790107 ,  1.19967598])

total_loss_2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))
sess.run(total_loss_2)
# 0.83934333897877922

请注意,total_loss_1total_loss_2会产生基本相同的结果,最后的数字会有一些小的差异。但是,您也可以使用第二种方法:它只需少一行代码并累积较少的数值误差,因为softmax是在softmax_cross_entropy_with_logits()内完成的。

答案 2 :(得分:44)

tf.nn.softmax计算通过softmax层的前向传播。在计算模型输出的概率时,可以在模型的求值期间使用它。

tf.nn.softmax_cross_entropy_with_logits计算softmax图层的成本。它仅在培训期间使用。

logits是非标准化日志概率输出模型(将softmax标准化应用于它们之前输出的值)。

答案 3 :(得分:3)

以上答案对问题提供了足够的描述。

除此之外,Tensorflow还优化了应用激活功能的操作,然后使用自己的激活后跟成本函数计算成本。因此,最好使用tf.nn.softmax_cross_entropy()而不是tf.nn.softmax(); tf.nn.cross_entropy()

您可以在资源密集型模型中找到它们之间的显着差异。

答案 4 :(得分:2)

Tensorflow 2.0兼容的答案dgastackoverflowuser2010的解释非常详细地介绍了Logits及其相关功能。

所有这些功能在 Tensorflow 1.x 中使用时都可以正常工作,但是如果您将代码从 1.x (1.14, 1.15, etc) 迁移到 { {1}} ,使用这些功能会导致错误。

因此,如果我们从 2.x (2.0, 2.1, etc..) 进行迁移,那么会为社区的利益指定所有功能的2.0兼容调用。

1.x中的功能

  1. 1.x to 2.x
  2. tf.nn.softmax
  3. tf.nn.softmax_cross_entropy_with_logits

从1.x迁移到2.x时的相应功能

  1. tf.nn.sparse_softmax_cross_entropy_with_logits
  2. tf.compat.v2.nn.softmax
  3. tf.compat.v2.nn.softmax_cross_entropy_with_logits

有关从1.x到2.x迁移的更多信息,请参阅此Migration Guide

答案 5 :(得分:2)

学期的数学动机

当我们希望输出一个约束在 0 和 1 之间的值,但我们的模型架构输出不受约束的值时,我们可以添加一个归一化层来强制执行此操作。

常见的选择1sigmoid函数。在二元分类中,这通常是逻辑函数,而在多类任务中,这是多项逻辑函数(又名 softmax2)。

如果我们想将新最后一层的输出解释为“概率”,那么(暗示)我们 sigmoid 的无约束输入必须是 inverse-sigmoid(概率)。在逻辑情况下,这相当于我们概率的log-odds(即odds的对数)又名logit

这就是为什么 softmax 的参数在 Tensorflow 中被称为 logits - 因为假设 softmax 是模型中的最后一层,并且输出 p< /em> 被解释为概率,该层的输入 x 可以解释为 logit:

<头>
enter image description here enter image description here

广义术语

在机器学习中,倾向于概括从数学/统计/计算机科学借来的术语,因此在 Tensorflow 中 logit(通过类比)被用作许多归一化函数输入的同义词。


  1. 虽然它具有很好的特性,例如易于微分和前面提到的概率解释,但它有点arbitrary
  2. softmax 可能更准确地称为 softargmax,因为它是一个 smooth approximation of the argmax function

答案 6 :(得分:1)

softmax的所有内容都是logit,这就是J. Hinton一直在Coursera视频中重复的内容。

答案 7 :(得分:0)

我肯定要强调的一件事是logit仅仅是原始输出,通常是最后一层的输出。这也可以是负值。如果我们将其用于“交叉熵”评估,如下所述:

-tf.reduce_sum(y_true * tf.log(logits))

然后它将无法正常工作。由于-ve的日志未定义。 因此,使用o softmax激活将克服此问题。

这是我的理解,如果我错了,请纠正我。