Why doesn't this code run for nsamples=3? It runs for nsamples in (1, 2).
from scipy.stats import multivariate_normal
import numpy as np
mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)
nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
state = np.meshgrid(x, y)
print state
rv.logpdf(state)
Here is the error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-11-595249b070ac> in <module>()
4 state = np.meshgrid(x, y)
5 print state
----> 6 rv.logpdf(state)
/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in logpdf(self, x)
518 x = _process_quantiles(x, self.dim)
519 out = self._mnorm._logpdf(x, self.mean, self.cov_info.U,
--> 520 self.cov_info.log_pdet, self.cov_info.rank)
521 return _squeeze_output(out)
522
/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
377
378 """
--> 379 dev = x - mean
380 maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
381 return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)
ValueError: operands could not be broadcast together with shapes (2,3,3) (2,)
It looks like there is a bug in the library: I think either x needs its axes rolled or mean needs to be reshaped.
Answer 0 (score: 1)
np.meshgrid returns a sequence of 2D arrays:
In [124]: np.meshgrid(x, y)
Out[124]:
[array([[-1., 0., 1.],
[-1., 0., 1.],
[-1., 0., 1.]]), array([[-2., -2., -2.],
[ 0., 0., 0.],
[ 2., 2., 2.]])]
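To see why nsamples=3 fails while nsamples=2 only appears to work: as the error message shows, logpdf turns that sequence into a single array of shape (2, nsamples, nsamples) and treats its last axis as the coordinate axis. The sketch below reuses the question's mean and covariance; the helper meshgrid_state is hypothetical and exists only for this illustration. It prints the shapes, confirms the broadcasting failure for nsamples=3, and checks that the nsamples=2 result is not the grid you actually want.

```python
import numpy as np
from scipy.stats import multivariate_normal

mean = np.array([0, 0])
covar = np.array([[1, 0], [0, 4]])
rv = multivariate_normal(mean, covar)


def meshgrid_state(nsamples):
    """Hypothetical helper: the question's meshgrid output as one array."""
    x = np.linspace(-1, 1, nsamples)
    y = np.linspace(-2, 2, nsamples)
    return np.asarray(np.meshgrid(x, y))


# The coordinate axis ends up first, not last: shape (2, n, n).
for n in (1, 2, 3):
    print(n, meshgrid_state(n).shape)

# For n=3 the last axis has length 3, so subtracting the length-2 mean fails.
try:
    rv.logpdf(meshgrid_state(3))
    raised = False
except ValueError:
    raised = True
print("n=3 raises ValueError:", raised)

# For n=2 the last axis happens to have length 2, so the call "succeeds",
# but each length-2 row of the grids is misread as an (x, y) point.
wrong = rv.logpdf(meshgrid_state(2))

# Correct values: the density at every (x, y) pair, with rows following y
# and columns following x (meshgrid's default layout).
x2 = np.linspace(-1, 1, 2)
y2 = np.linspace(-2, 2, 2)
right = np.array([[rv.logpdf([xi, yj]) for xi in x2] for yj in y2])
print("n=2 matches the true grid:", np.allclose(wrong, right))
```

So nsamples=2 does not really work; the shapes just line up by coincidence.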
rv.logpdf expects either a list of 2-tuples or an array whose last axis has length 2, such as this array of all (x, y) pairs:
In [128]: state
Out[128]:
array([[-1., -2.],
[-1., 0.],
[-1., 2.],
[ 0., -2.],
[ 0., 0.],
[ 0., 2.],
[ 1., -2.],
[ 1., 0.],
[ 1., 2.]])
In [129]: rv.logpdf(state)
Out[129]:
array([-3.53102425, -3.03102425, -3.53102425, -3.03102425, -2.53102425,
-3.03102425, -3.53102425, -3.03102425, -3.53102425])
In [131]: rv.logpdf(state.reshape(3,3,-1))
Out[131]:
array([[-3.53102425, -3.03102425, -3.53102425],
[-3.03102425, -2.53102425, -3.03102425],
[-3.53102425, -3.03102425, -3.53102425]])
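The reshape yields a meaningful grid because itertools.product varies y fastest, so the flat rows come in x-major order. With a square, symmetric grid the orientation is invisible. The following minimal check uses differently sized x and y (4 and 5 points, chosen only for illustration). It shows that reshaping to (len(x), len(y), -1) puts x on the rows and y on the columns, so result[i, j] is the log-density at (x[i], y[j]).

```python
import itertools as IT

import numpy as np
from scipy.stats import multivariate_normal

rv = multivariate_normal(np.array([0, 0]), np.array([[1, 0], [0, 4]]))

# Different lengths make the row/column orientation visible.
x = np.linspace(-1, 1, 4)
y = np.linspace(-2, 2, 5)

# itertools.product(x, y) varies y fastest, so after the reshape
# rows correspond to x and columns to y.
state = np.array(list(IT.product(x, y)))            # shape (20, 2)
grid = rv.logpdf(state.reshape(len(x), len(y), -1))  # shape (4, 5)

i, j = 3, 1
print(grid.shape)
print(np.isclose(grid[i, j], rv.logpdf([x[i], y[j]])))
```

This is the transpose of np.meshgrid's default layout, where rows follow y.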
So instead of np.meshgrid, you could use itertools.product:
state = np.array(list(IT.product(x, y)))
Or, for better speed when x and y are large, use pv's cartesian function.
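pv's function is not reproduced here. As a rough illustration of the same vectorized idea for exactly two 1-D inputs, you can build the pairs from np.meshgrid with indexing='ij'. That produces rows in the same order as itertools.product(x, y). The name cartesian2 is hypothetical.

```python
import numpy as np


def cartesian2(x, y):
    """Hypothetical helper: all (x[i], y[j]) pairs in itertools.product order.

    With indexing='ij', X[i, j] == x[i] and Y[i, j] == y[j], so flattening
    in C order makes y vary fastest, just as itertools.product(x, y) does.
    """
    X, Y = np.meshgrid(x, y, indexing='ij')
    return np.stack([X.ravel(), Y.ravel()], axis=-1)


x = np.linspace(-1, 1, 3)
y = np.linspace(-2, 2, 3)
print(cartesian2(x, y))
```

The result is a drop-in replacement for np.array(list(IT.product(x, y))) in the code below.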
from scipy.stats import multivariate_normal
import numpy as np
import itertools as IT
mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)
nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
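# All (x, y) pairs as rows: shape (nsamples**2, 2)
state = np.array(list(IT.product(x, y)))
# Reshape to (nsamples, nsamples, 2) so the coordinate axis is last;
# logpdf then returns an (nsamples, nsamples) grid
logpdf = rv.logpdf(state.reshape(nsamples, nsamples, -1))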
print(logpdf)
yields
[[-3.53102425 -3.03102425 -3.53102425]
[-3.03102425 -2.53102425 -3.03102425]
[-3.53102425 -3.03102425 -3.53102425]]
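If you would rather keep np.meshgrid, the "roll the axis" idea from the question also works on the input side, with no library change. Stack the two grids along a new last axis so the array has shape (nsamples, nsamples, 2). This is an alternative sketch, not part of the original answer. Because meshgrid's default indexing='xy' puts y on the rows, the result is the transpose of the product-based grid. For this particular grid the two are identical.

```python
import numpy as np
from scipy.stats import multivariate_normal

mean = np.array([0, 0])
covar = np.array([[1, 0], [0, 4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)

X, Y = np.meshgrid(x, y)           # each of shape (nsamples, nsamples)
state = np.stack([X, Y], axis=-1)  # shape (nsamples, nsamples, 2)
logpdf = rv.logpdf(state)          # shape (nsamples, nsamples); rows follow y
print(logpdf)
```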