python多元正常pdf 3d图

时间:2017-10-31 21:31:25

标签: python

我试图在mnist数据集中显示multivist正常pdf的三维图表。

from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

my0 = np.mean(num_arrays[0],axis=0)
sigma0 = np.identity(784)
p0 = multivariate_normal(my0,sigma0)

X, Y = np.mgrid[-10:10:.1, -10:10:.1]
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, p0_id.pdf(pos),cmap='viridis',linewidth=0)

我收到以下错误消息:

operands could not be broadcast together with shapes (200,200,2) (784,)

我在这里做错了什么?

编辑:完整错误消息

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-70-584e158fe420> in <module>()
     13 fig = plt.figure()
     14 ax = fig.gca(projection='3d')
---> 15 ax.plot_surface(X, Y, p0.pdf(pos),cmap='viridis',linewidth=0)

~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in pdf(self, x)
    608 
    609     def pdf(self, x):
--> 610         return np.exp(self.logpdf(x))
    611 
    612     def rvs(self, size=1, random_state=None):

~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in logpdf(self, x)
    604         x = self._dist._process_quantiles(x, self.dim)
    605         out = self._dist._logpdf(x, self.mean, self.cov_info.U,
--> 606                                  self.cov_info.log_pdet, self.cov_info.rank)
    607         return _squeeze_output(out)
    608 

~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
    452 
    453         """
--> 454         dev = x - mean
    455         maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
    456         return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)

ValueError: operands could not be broadcast together with shapes (200,200,2) (784,) 

1 个答案:

答案 0 :(得分:0)

我试图做一些相同的东西,并且我发现唯一的想法是了解原始形状它计算函数的点对点结果并用{{绘制这一点1}}函数,这是我如何使用python

获得3D高斯形状的示例
Axes3D.scatter()

当meshgrid生成一个多维数组时,我认为import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.stats import multivariate_normal mean = np.array([1., 1.]) cov_matrix = np.array([[2., 0.], [0., 2.]]) fig = plt.figure() ax = fig.gca(projection="3d") x = np.linspace(-3., 3., 20) y = np.linspace(-3., 3., 20) for i in x: for j in y: ax.scatter(i, j, pdf_2d(i, j, multivariate_normal.pdf([i, j], mean=mean_, cov=cov_matrix_)) plt.show() 函数无法对每个元素应用矢量转换。

不幸的是,在我的案例中,我发现只有这样才能了解高斯的形状。