使用scipy truncnorm的numpy广播

时间:2017-09-25 23:58:31

标签: python python-2.7 numpy scipy numpy-broadcasting

我想评估分位数的不同值和未截断平均值的不同值的单侧截断正态分布。为了提高效率,我想使用numpy广播而不是Python循环。

对于最小可重复的示例,假设我要评估的三个分位数为[3.0, 2.0, 1.0],相应的未截断平均值为[6.0, 5.0, 4.0],下限为1.5,并且未截断的标准差为3.0

单独评估这些按预期工作。如果我跑

import numpy as np
from scipy.stats import truncnorm
print truncnorm.logpdf(3.0, a=(1.5-6.0)/3.0, b=np.inf, loc=6.0, scale=3.0)
print truncnorm.logpdf(2.0, a=(1.5-5.0)/3.0, b=np.inf, loc=5.0, scale=3.0)
print truncnorm.logpdf(1.0, a=(1.5-4.0)/3.0, b=np.inf, loc=4.0, scale=3.0)

我得到了

-2.44840736626
-2.3878150686
-inf

(最后一个值为-inf,因为1.0小于截止值。一次使用numpy广播两个值也可以按预期工作。如果我跑

print truncnorm.logpdf(
    np.array([3.0, 2.0]),
    a=(1.5-np.array([6.0, 5.0]))/3.0,
    b=np.inf,
    loc=np.array([6.0, 5.0]),
    scale=3.0
)
print truncnorm.logpdf(
    np.array([2.0, 1.0]),
    a=(1.5-np.array([5.0, 4.0]))/3.0,
    b=np.inf,
    loc=np.array([5.0, 4.0]),
    scale=3.0
)

我得到了

[-2.44840737 -2.38781507]
[-2.38781507        -inf]

但是,如果我尝试通过运行一次评估三个值:

print truncnorm.logpdf(
    np.array([3.0, 2.0, 1.0]),
    a=(1.5-np.array([6.0, 5.0, 4.0]))/3.0,
    b=np.inf,
    loc=np.array([6.0, 5.0, 4.0]),
    scale=3.0
)

我收到错误:

Traceback (most recent call last):
  File "truncnorm_error.py", line 25, in <module>
    scale=3.0
  File "C:\Python27\lib\site-packages\scipy\stats\_distn_infrastructure.py", line 1701, in logpdf
    place(output, cond, self._logpdf(*goodargs) - log(scale))
  File "C:\Python27\lib\site-packages\scipy\stats\_continuous_distns.py", line 4853, in _logpdf
    return _norm_logpdf(x) - self._logdelta
ValueError: operands could not be broadcast together with shapes (2,) (3,)

我错过了什么?我使用的是Python 2.7,numpy 1.13和scipy 0.19。

2 个答案:

答案 0 :(得分:0)

这不起作用的原因,因为logpdf检查分位数以确保它们大于截止值。如果你的值小于截断值,显然它适用于大小1和2,但不适用于3.所以这可能是错误。

如果提供的值大于截断值,则可以正常工作。例如,这是有效的,我将分位数改为1.0到1.6:

print truncnorm.logpdf(
    np.array([3.0, 2.0, 1.6]),
    a=(1.5-np.array([6.0, 5.0, 4.0]))/3.0,
    b=np.inf,
    loc=np.array([6.0, 5.0, 4.0]),
    scale=3.0)

答案 1 :(得分:0)

谢谢,所有。与此同时,我自己动手:

x

它不优雅,我确定它有问题,但它似乎适用于我的目的(例如,它正确地广播untruncated_mean和{{1}}的矢量参数。