我正在尝试实现以下softplus功能:
log(1 + exp(x))
我已尝试将math / numpy和float64作为数据类型,但只要x
过大(例如x = 1000
),结果就是inf
。
你能帮我解决一下如何用大数字成功处理这个功能吗?
答案 0 :(得分:8)
有一种关系可以使用:
log(1+exp(x)) = log(1+exp(x)) - log(exp(x)) + x = log(1+exp(-x)) + x
因此,从数学上讲,安全的实施方式应该是:
log(1+exp(-abs(x))) + max(x,0)
这既适用于数学函数又适用于numpy函数(例如:np.log,np.exp,np.abs,np.maximum)。
答案 1 :(得分:5)
由于x>30
我们有log(1+exp(x)) ~= log(exp(x)) = x
,因此一个简单的稳定实现
def safe_softplus(x, limit=30):
if x>limit:
return x
else:
return np.log(1.0 + np.exp(x))
实际上| log(1+exp(30)) - 30 | < 1e-10
,因此此实现使错误小于1e-10
并且永远不会溢出。特别是对于x = 1000,此近似值的误差将远小于float64分辨率,因此甚至无法在计算机上进行测量。
答案 2 :(得分:0)
我使用此代码在数组中工作
def safe_softplus(x):
inRanges = (x < 100)
return np.log(1 + np.exp(x*inRanges))*inRanges + x*(1-inRanges)
答案 3 :(得分:0)
我目前使用的(效率稍低但干净且矢量化):
def safe_softplus(x, limit=30):
return np.where(x>limit, x, np.log1p(np.exp(x)))