所以我一直尝试(不成功)一段时间来获得三个垂直堆叠的pcolormesh图来共享一个颜色条。各种Google搜索引导我使用matplotlib工具包中的ImageGrid功能,我开始喜欢它(尽管我不太了解它)。
我似乎无法找到的一个特殊功能是能够跨多个列“拉伸”轴对象,就像使用subplot2grid中的colspan参数一样。使用演示代码found here,我就能够生成这个图像了。
正如您所看到的,它当前位于图形空间的中间,两侧的空白量等于我假设的中心轴宽度。简单地说,我希望这三个地块中的每一个都能够在空间的宽度上延伸,同时将颜色条保持在地块的右侧并保持在这个高度。
到目前为止,我尝试过的代码(实际上只是缺少位的示例代码)。任何帮助将不胜感激!
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
def get_demo_image():
from matplotlib.cbook import get_sample_data
f = get_sample_data("axes_grid/bivariate_normal.npy", asfileobj=False)
z = np.load(f)
# z is a numpy array of 15x15
return z, (-3, 4, -4, 3)
if 1:
F = plt.figure(1, (6, 6))
F.clf()
# prepare images
Z, extent = get_demo_image()
ZS = [Z[i::3, :] for i in range(3)]
extent = extent[0], extent[1]/3., extent[2], extent[3]
# demo 2 : shared colorbar
grid2 = ImageGrid(F, 111,
nrows_ncols=(3, 1),
direction="column",
axes_pad=0,
add_all=True,
label_mode="1",
share_all=False,
cbar_location="right",
cbar_mode="single",
cbar_size="10%",
cbar_pad=0.05,
)
grid2[0].set_xlabel("X")
grid2[0].set_ylabel("Y")
vmax, vmin = np.max(ZS), np.min(ZS)
import matplotlib.colors
norm = matplotlib.colors.Normalize(vmax=vmax, vmin=vmin)
for ax, z in zip(grid2, ZS):
im = ax.imshow(z, norm=norm,
origin="lower", extent=extent,
interpolation="nearest")
# With cbar_mode="single", cax attribute of all axes are identical.
ax.cax.colorbar(im)
ax.cax.toggle_label(True)
grid2[0].set_xticks([-2, 0])
grid2[0].set_yticks([-2, 0, 2])
plt.draw()
plt.show()
答案 0 :(得分:1)
ImageGrid
旨在托管具有相同宽高比的图像。如果你不想拥有相同方面的图像,使用通常的子图可能会容易得多。要在网格上查找子图,使用matplotlib.gridspec
会很有帮助。
import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)
import matplotlib.gridspec as gridspec
gs = gridspec.GridSpec(3, 2, width_ratios=[1,.05])
fig = plt.figure()
fig.subplots_adjust(wspace=0.1, hspace=0)
norm=plt.Normalize(0,1)
for i in range(3):
ax = fig.add_subplot(gs[i,0])
im= ax.imshow(np.random.rand(5,36)*0.3+i*0.3, norm=norm, aspect="auto")
cax = fig.add_subplot(gs[:,1])
fig.colorbar(im, cax=cax)
plt.show()