Matplotlib imshow()给出水平翻转的密度图

时间:2018-10-11 05:23:52

标签: python matplotlib

我对绘制双变量正态混合图很感兴趣。

def func(x):
    cat = tfd.Categorical(probs=np.array([.5, .5],dtype=NP_DTYPE))
    comps = [tfd.MultivariateNormalDiag(loc=np.array([-5.0, -5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1),
         tfd.MultivariateNormalDiag(loc=np.array([5.0, 5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1)]

    mix = tfd.Mixture(cat=cat, components=comps)
    return mix.prob(x)

这是两个双变量正态的混合。一个中心位于[5,5],另一个中心位于[-5,-5];都具有对角线协方差矩阵,沿对角线的距离为0.1。每个都具有相同的混合重量。

我的绘图代码是

# make these smaller to increase the resolution
dx, dy = 0.1, 0.1

x = np.arange(-10.0, 10.0, dx)
y = np.arange(-10.0, 10.0, dy)

X, Y = np.meshgrid(x, y)
Z = np.concatenate((X.reshape(-1,1),Y.reshape(-1,1)),axis=1)


extent = np.min(x), np.max(x), np.min(y), np.max(y)
fig = plt.figure(frameon=True)



Z2 = tf.log(func(Z) + 1e-6)
Z2 = sess.run(Z2)
Z2 = Z2.reshape(int(np.sqrt(Z2.shape[0])),int(np.sqrt(Z2.shape[0])))

im2 = plt.imshow(Z2, cmap=plt.cm.viridis, alpha=.9, interpolation='bilinear',
                 extent=extent)
plt.colorbar()
plt.show()

之所以我将网格栅格化是因为我想实现通用的密度图,以便可以绘制任何复杂的2D分布。 (任意二维分布,其输入形状为[N,D]; N是点的数量,D是每个点的尺寸)

但是,这给出了一个奇怪的图

enter image description here

由于高温区域水平翻转了图,因此应该在[5,5]和[-5,-5]左右

有任何人解决这个问题吗? (因为imshow()是一种黑盒,并且我需要密度函数采用特定形式的输入;我不知道如何解决此问题)

1 个答案:

答案 0 :(得分:0)

origin="lower"添加到imshow图的参数中。

默认值为"upper",这对于绘制图像通常很有意义,因为它们通常从左上角开始。但是,对于基于矩阵的绘图,通常需要使用origin="lower",否则您的绘图将被翻转。

im2 = plt.imshow(Z2, cmap=plt.cm.viridis, alpha=.9, interpolation='bilinear',
                 extent=extent,origin='lower')
plt.colorbar()
plt.show()

enter image description here