我对绘制双变量正态混合图很感兴趣。
def func(x):
cat = tfd.Categorical(probs=np.array([.5, .5],dtype=NP_DTYPE))
comps = [tfd.MultivariateNormalDiag(loc=np.array([-5.0, -5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1),
tfd.MultivariateNormalDiag(loc=np.array([5.0, 5.0],dtype=NP_DTYPE), scale_diag=tf.ones(2,dtype=DTYPE)*.1)]
mix = tfd.Mixture(cat=cat, components=comps)
return mix.prob(x)
这是两个双变量正态的混合。一个中心位于[5,5],另一个中心位于[-5,-5];都具有对角线协方差矩阵,沿对角线的距离为0.1。每个都具有相同的混合重量。
我的绘图代码是
# make these smaller to increase the resolution
dx, dy = 0.1, 0.1
x = np.arange(-10.0, 10.0, dx)
y = np.arange(-10.0, 10.0, dy)
X, Y = np.meshgrid(x, y)
Z = np.concatenate((X.reshape(-1,1),Y.reshape(-1,1)),axis=1)
extent = np.min(x), np.max(x), np.min(y), np.max(y)
fig = plt.figure(frameon=True)
Z2 = tf.log(func(Z) + 1e-6)
Z2 = sess.run(Z2)
Z2 = Z2.reshape(int(np.sqrt(Z2.shape[0])),int(np.sqrt(Z2.shape[0])))
im2 = plt.imshow(Z2, cmap=plt.cm.viridis, alpha=.9, interpolation='bilinear',
extent=extent)
plt.colorbar()
plt.show()
之所以我将网格栅格化是因为我想实现通用的密度图,以便可以绘制任何复杂的2D分布。 (任意二维分布,其输入形状为[N,D]; N是点的数量,D是每个点的尺寸)
但是,这给出了一个奇怪的图
由于高温区域水平翻转了图,因此应该在[5,5]和[-5,-5]左右
有任何人解决这个问题吗? (因为imshow()是一种黑盒,并且我需要密度函数采用特定形式的输入;我不知道如何解决此问题)