商业稳定的softmax

时间:2017-03-04 18:11:08

标签: python numpy nan scientific-computing softmax

下面有一个数值稳定的方法来计算softmax函数吗? 我在神经网络代码中获得的值变为Nans。

np.exp(x)/np.sum(np.exp(y))

5 个答案:

答案 0 :(得分:28)

softmax exp( x )/ sum(exp( x ))实际上在数值上表现良好。它只有正项,所以我们不必担心重要性的损失,分母至少和分子一样大,所以结果保证在0到1之间。

唯一可能发生的事故是指数过度或不足。 x 的所有元素的单个或下溢的溢出将使输出或多或少无用。

但是通过使用标识softmax( x )= softmax( x + c)来保护任何标量c:减去max(通过减去max)很容易来自 x x 会留下一个只有非正数条目的向量,排除溢出,并且至少有一个零元素排除一个消失的分母(在某些情况下下溢)并非所有条目都是无害的。)

注意:从理论上讲,总和中的灾难性事故是可能的,但你需要一些荒谬的术语,并且是荒谬的不幸。此外,numpy使用相对较强的成对求和。

答案 1 :(得分:5)

Softmax功能容易出现两个问题:溢出下溢

溢出:非常大的数字近似时出现infinity

下溢:当非常小的数字(数字行中接近零)近似(即舍入到)为zero

时,会发生这种情况

要在进行softmax计算时解决这些问题,常见的技巧是将输入向量移位从所有元素中减去其中的最大元素。对于输入向量x,请定义z,以便:

z = x-max(x)

然后取新(稳定)向量z

的softmax

示例:

In [266]: def stable_softmax(x):
     ...:     z = x - max(x)
     ...:     numerator = np.exp(z)
     ...:     denominator = np.sum(numerator)
     ...:     softmax = numerator/denominator
     ...:     return softmax
     ...: 

In [267]: vec = np.array([1, 2, 3, 4, 5])

In [268]: stable_softmax(vec)
Out[268]: array([ 0.01165623,  0.03168492,  0.08612854,  0.23412166,  0.63640865])

In [269]: vec = np.array([12345, 67890, 99999999])

In [270]: stable_softmax(vec)
Out[270]: array([ 0.,  0.,  1.])

有关详细信息,请参阅 Numerical Computation 一书中 deep learning 一章。

答案 2 :(得分:1)

感谢Paul Panzer's解释,但我想知道为什么我们需要减去max(x)。因此,我发现了更详细的信息,并希望它对与我有同样问题的人有所帮助。 请参阅以下链接文章中的“最大减法是什么?”部分。

https://nolanbconaway.github.io/blog/2017/softmax-numpy

答案 3 :(得分:0)

计算softmax函数没有任何问题,就像你的情况一样。问题似乎来自爆炸梯度或您的训练方法的这类问题。使用“裁剪值”或“选择正确的权重初始分布”来关注这些问题。

答案 4 :(得分:0)

扩展@ kmario23的答案以支持一维或二维维numpy数组或列表(如果要通过softmax函数传递一批结果,则很常见):

import numpy as np


def stable_softmax(x):
    z = x - np.max(x, axis=-1, keepdims=True)
    numerator = np.exp(z)
    denominator = np.sum(numerator, axis=-1, keepdims=True)
    softmax = numerator / denominator
    return softmax


test1 = np.array([12345, 67890, 99999999])  # 1D
test2 = np.array([[12345, 67890, 99999999], [123, 678, 88888888]])  # 2D
test3 = [12345, 67890, 999999999]
test4 = [[12345, 67890, 999999999]]

print(stable_softmax(test1))
print(stable_softmax(test2))
print(stable_softmax(test3))
print(stable_softmax(test4))

 [0. 0. 1.]

[[0. 0. 1.]
 [0. 0. 1.]]

 [0. 0. 1.]

[[0. 0. 1.]]