我正在numpy中实现以下功能:
def weak_softmax(a):
b=np.exp(a)
return b/(1+np.sum(b))
数组a
的大小很小,但条目有时可能很大,可能与1000
一样大。因此,由于指数函数溢出,我经常收到以下错误:
a=np.array([1000,1000])
a=weak_softmax(a)
上面的代码返回向量a=[nan nan]
并引发以下警告:
Warning: overflow encountered in exp
是否有任何聪明的方法可以避免此问题,但仍然按预期返回数组b
?这是因为b
的所有条目仅少于一个,我觉得必须有可能使用一些技巧来避免这个问题。
答案 0 :(得分:0)
对于大小合适的exp(c)
,您可以简单地将分子和分母除以相同因子c
。
以下代码使用np.finfo
检查是否可能发生溢出并计算c
。
def modified_soft_max(a, SAFETY=2.0):
mrn = np.finfo(a.dtype).max # largest representable number
thr = np.log(mrn / a.size) - SAFETY
amx = a.max()
if(amx > thr):
b = np.exp(a - (amx-thr))
return b / (np.exp(thr-amx) + b.sum())
else:
b = np.exp(a)
return b / (1.0 + b.sum())