Why doesn't the scipy multivariate_normal pdf ufunc work?

Date: 2015-07-30 17:07:23

Tags: python-2.7 statistics scipy

Why does this code fail for nsamples=3? It runs for nsamples in (1, 2).

from scipy.stats import multivariate_normal
import numpy as np

mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
state = np.meshgrid(x, y)
print state
rv.logpdf(state)

Here is the error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-595249b070ac> in <module>()
      4 state = np.meshgrid(x, y)
      5 print state
----> 6 rv.logpdf(state)

/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in logpdf(self, x)
    518         x = _process_quantiles(x, self.dim)
    519         out = self._mnorm._logpdf(x, self.mean, self.cov_info.U,
--> 520                                   self.cov_info.log_pdet, self.cov_info.rank)
    521         return _squeeze_output(out)
    522 

/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
    377 
    378         """
--> 379         dev = x - mean
    380         maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
    381         return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)

ValueError: operands could not be broadcast together with shapes (2,3,3) (2,) 

This looks like a bug in the library: I think either x needs its axes rolled or mean needs to be reshaped.

1 Answer:

Answer 0: (score: 1)

np.meshgrid returns a sequence of 2D arrays, one per coordinate, rather than a single array of points:

In [124]: np.meshgrid(x, y)
Out[124]: 
[array([[-1.,  0.,  1.],
        [-1.,  0.,  1.],
        [-1.,  0.,  1.]]), array([[-2., -2., -2.],
        [ 0.,  0.,  0.],
        [ 2.,  2.,  2.]])]
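To see why only nsamples=3 raises, it may help to look at the shapes involved. The sketch below uses plain NumPy and does not call SciPy; it assumes, as the traceback suggests, that the stacked meshgrid output is subtracted from a mean of shape (2,). The names grid and dev_shapes are only illustrative.

```python
import numpy as np

mean = np.zeros(2)  # same shape as the mean in the question: (2,)

dev_shapes = {}
for nsamples in (1, 2, 3):
    x = np.linspace(-1, 1, nsamples)
    y = np.linspace(-2, 2, nsamples)
    # Stacking the two meshgrid arrays gives shape (2, nsamples, nsamples).
    grid = np.asarray(np.meshgrid(x, y))
    try:
        # Broadcasting aligns trailing axes: (2, n, n) against (2,).
        dev_shapes[nsamples] = (grid - mean).shape
    except ValueError:
        # The trailing axis (length n) is neither 1 nor 2, so this fails.
        dev_shapes[nsamples] = None
    print(nsamples, grid.shape, dev_shapes[nsamples])
```

For nsamples of 1 or 2 the subtraction happens to broadcast, so the call runs, but the axes are not being read as (x, y) points, so the result is not the intended one.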

rv.logpdf expects a list of 2-tuples, or an array whose last axis has length 2:

In [128]: state
Out[128]: 
array([[-1., -2.],
       [-1.,  0.],
       [-1.,  2.],
       [ 0., -2.],
       [ 0.,  0.],
       [ 0.,  2.],
       [ 1., -2.],
       [ 1.,  0.],
       [ 1.,  2.]])

In [129]: rv.logpdf(state)
Out[129]: 
array([-3.53102425, -3.03102425, -3.53102425, -3.03102425, -2.53102425,
       -3.03102425, -3.53102425, -3.03102425, -3.53102425])

In [131]: rv.logpdf(state.reshape(3,3,-1))
Out[131]: 
array([[-3.53102425, -3.03102425, -3.53102425],
       [-3.03102425, -2.53102425, -3.03102425],
       [-3.53102425, -3.03102425, -3.53102425]])
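If you would rather keep np.meshgrid, one possible approach is to stack its two outputs along a new last axis so that each grid cell holds an (x, y) pair. This is a minimal sketch using the mean and covariance from the question; pos and grid_logpdf are illustrative names.

```python
import numpy as np
from scipy.stats import multivariate_normal

rv = multivariate_normal(np.array([0, 0]), np.array([[1, 0], [0, 4]]))

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)

xx, yy = np.meshgrid(x, y)          # two arrays, each of shape (3, 3)
pos = np.stack((xx, yy), axis=-1)   # shape (3, 3, 2): last axis is (x, y)
grid_logpdf = rv.logpdf(pos)        # shape (3, 3)
print(grid_logpdf)
```

Note that with the default indexing='xy', rows of the result follow y and columns follow x. Passing indexing='ij' to np.meshgrid puts x on the rows instead, which matches the reshaped itertools.product ordering.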

So, instead of np.meshgrid, you could use itertools.product:

state = np.array(list(IT.product(x, y)))

Or, for better speed when x and y are large, use pv's cartesian function. The full script below uses itertools.product:

from scipy.stats import multivariate_normal
import numpy as np
import itertools as IT

mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
state = np.array(list(IT.product(x, y)))
logpdf = rv.logpdf(state.reshape(nsamples, nsamples, -1))
print(logpdf)

This prints:

[[-3.53102425 -3.03102425 -3.53102425]
 [-3.03102425 -2.53102425 -3.03102425]
 [-3.53102425 -3.03102425 -3.53102425]]
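For larger x and y, building the pairs in NumPy avoids the Python-level loop in itertools.product. The helper below is a hypothetical sketch, not pv's cartesian function; it only handles the two-array case and is meant to produce the pairs in the same order as itertools.product(x, y).

```python
import itertools

import numpy as np
from scipy.stats import multivariate_normal


def cartesian_pairs(x, y):
    """Hypothetical helper: all (x_i, y_j) pairs as an array of shape (len(x)*len(y), 2)."""
    # indexing="ij" makes x vary along rows, so ravel() yields x as the slow index.
    xx, yy = np.meshgrid(x, y, indexing="ij")
    return np.column_stack((xx.ravel(), yy.ravel()))


rv = multivariate_normal(np.array([0, 0]), np.array([[1, 0], [0, 4]]))

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)

state = cartesian_pairs(x, y)
# Check that the ordering agrees with itertools.product.
same_order = np.array_equal(state, np.array(list(itertools.product(x, y))))
logpdf = rv.logpdf(state.reshape(nsamples, nsamples, -1))
print(same_order)
print(logpdf)
```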