我需要logit和inverse logit函数,以便logit(inv_logit(n)) == n
。我使用numpy,这就是我所拥有的:
import numpy as np
def logit(p):
return np.log(p) - np.log(1 - p)
def inv_logit(p):
return np.exp(p) / (1 + np.exp(p))
以下是价值观:
print logit(inv_logit(2))
2.0
print logit(inv_logit(10))
10.0
print logit(inv_logit(20))
20.000000018 #well, pretty close
print logit(inv_logit(50))
Warning: divide by zero encountered in log
inf
现在让我们测试负数
print logit(inv_logit(-10))
-10.0
print logit(inv_logit(-20))
-20.0
print logit(inv_logit(-200))
-200.0
print logit(inv_logit(-500))
-500.0
print logit(inv_logit(-2000))
Warning: divide by zero encountered in log
-inf
所以我的问题是:实现这些函数的正确方法是什么,以使logit(inv_logit(n)) == n
要求n
适用于尽可能宽的范围内的任何{{1}}(至少[-1e4; 1e4] )?
而且(我确信这与第一个相关),为什么我的函数与正值相比更稳定?
答案 0 :(得分:11)
使用
1。 bigfloat包支持任意精度浮动点操作。
2。 SymPy 符号数学包。我将举两个例子:
首先,bigfloat:
http://packages.python.org/bigfloat/
这是一个简单的例子:
from bigfloat import *
def logit(p):
with precision(100000):
return log(p)- log(1 -BigFloat(p))
def inv_logit(p):
with precision(100000):
return exp(p) / (1 + exp(p))
int(round(logit(inv_logit(12422.0))))
# gives 12422
int(round(logit(inv_logit(-12422.0))))
# gives -12422
这真的很慢。您可能需要考虑重构问题并分析一些部分。像这样的案例在实际问题中很少见 - 我很好奇你正在处理什么样的问题。
示例安装:
wget http://pypi.python.org/packages/source/b/bigfloat/bigfloat-0.3.0a2.tar.gz
tar xvzf bigfloat-0.3.0a2.tar.gz
cd bigfloat-0.3.0a2
as root:
python setup.py install
关于您的功能使用负值更好的原因。考虑:
>>> float(inv_logit(-15))
3.059022269256247e-07
>>> float(inv_logit(15))
0.9999996940977731
在第一种情况下,浮点数很容易表示该值。移动小数点,以便不需要存储前导零:0.0000 ....在第二种情况下,需要存储所有前导0.999,因此当稍后在logit()中执行1-p时,您需要所有额外的精度来获得精确的结果。
这是符号数学方式(明显更快!):
from sympy import *
def inv_logit(p):
return exp(p) / (1 + exp(p))
def logit(p):
return log(p)- log(1 -p)
x=Symbol('x')
expr=logit(inv_logit(x))
# expr is now:
# -log(1 - exp(x)/(1 + exp(x))) + log(exp(x)/(1 + exp(x)))
# rewrite it: (there are many other ways to do this. read the doc)
# you may want to make an expansion (of some suitable kind) instead.
expr=cancel(powsimp(expr)).expand()
# it is now 'x'
# just evaluate any expression like this:
result=expr.subs(x,123.231)
# result is now an equation containing: 123.231
# to get the float:
result.evalf()
Sympy在http://docs.sympy.org/找到。在ubuntu中,它是通过synaptic找到的。
答案 1 :(得分:6)
有一种方法可以实现这些功能,使它们在各种值中都很稳定,但它根据参数区分不同的情况。
以inv_logit函数为例。你的公式“np.exp(p)/(1 + np.exp(p))”是正确的,但是对于大p会溢出。如果用np.exp(p)除以分子和分母,则得到等价表达式
1. / (1. + np.exp(-p))
不同之处在于,对于大正p而言,这个不会溢出。然而,对于p的大负值,它会溢出。因此,稳定的实施可以如下:
def inv_logit(p):
if p > 0:
return 1. / (1. + np.exp(-p))
elif p <= 0:
np.exp(p) / (1 + np.exp(p))
else:
raise ValueError
这是库LIBLINEAR(以及其他可能的库)中使用的策略。
答案 2 :(得分:1)
您正在遇到IEEE 754双精度浮点数的精度限制。如果您想要更大的范围和更精确的域,则需要使用更高精度的数字和操作。
>>> 1 + np.exp(-37)
1.0
>>> 1 + decimal.Decimal(-37).exp()
Decimal('1.000000000000000085330476257')
答案 3 :(得分:0)
我对Fabian Pedregosa的回答的变体:
def stable_inv_logit(x):
return 0.5*(1. + np.sign(x)*(2./(1. + np.exp(-np.abs(x))) - 1.))
答案 4 :(得分:0)