在numpy中避免exp的溢出错误

时间:2017-01-31 00:25:13

标签: python numpy

我正在numpy中实现以下功能:

def weak_softmax(a):
    b=np.exp(a)
    return b/(1+np.sum(b))

数组a的大小很小,但条目有时可能很大,可能与1000一样大。因此,由于指数函数溢出,我经常收到以下错误:

a=np.array([1000,1000])
a=weak_softmax(a)

上面的代码返回向量a=[nan nan]并引发以下警告:

Warning: overflow encountered in exp

是否有任何聪明的方法可以避免此问题,但仍然按预期返回数组b?这是因为b的所有条目仅少于一个,我觉得必须有可能使用一些技巧来避免这个问题。

1 个答案:

答案 0 :(得分:0)

对于大小合适的exp(c),您可以简单地将分子和分母除以相同因子c

以下代码使用np.finfo检查是否可能发生溢出并计算c

def modified_soft_max(a, SAFETY=2.0):
    mrn = np.finfo(a.dtype).max # largest representable number
    thr = np.log(mrn / a.size) - SAFETY
    amx = a.max()
    if(amx > thr):
        b = np.exp(a - (amx-thr))
        return b / (np.exp(thr-amx) + b.sum())
    else:
        b = np.exp(a)
        return b / (1.0 + b.sum())