共享X轴时,移除imshow周围的框

时间:2017-08-20 22:31:40

标签: python matplotlib

当我尝试堆叠多个imshow元素时,在垂直轴上会出现一些额外的空白区域,并且标题看起来与其他图形太接近了。

我认为这两个问题都是由sharex=True引起的,但我不知道如何解决这些问题。

fig.tight_layout()几乎解决了这个问题,但它与侧面的颜色条不兼容,并使一些方块比其他方块小。

生成图像的代码是

# Values is a [(ndarray, string)]
fig, axes = plt.subplots(len(values), sharex=True)
for ax, (value, plot_name) in zip(axes, values):
    im = ax.imshow(value, vmax=1.0, vmin=0.0)
    ax.set_title(plot_name)

# (Hack) Apply on the last one
plt.xticks(range(values.shape[1]), ticks, rotation=90)
plt.colorbar(im, ax=axes.ravel().tolist())

fig.savefig(output_name, bbox_inches="tight")

示例图片是: enter image description here

2 个答案:

答案 0 :(得分:2)

gridspec_kw={"hspace": 0.8}参数添加到plt.subplots构造函数使它对我有用。这控制了我认为的子图之间的垂直空间

ticks = ["blah" for i in range(17)]
# Values is a [(ndarray, string)]
values = [(np.random.randn(3,17), "Title") for i in range(3)]
fig, axes = plt.subplots(len(values), sharex=True, gridspec_kw={"hspace": 0.8})
for ax, (value, plot_name) in zip(axes, values):
    im = ax.imshow(value, vmax=1.0, vmin=0.0)
    ax.set_title(plot_name)

# (Hack) Apply on the last one
plt.xticks(range(values[0][0].shape[1]), ticks, rotation=90)
plt.colorbar(im, ax=axes.ravel().tolist())

plt.show()

enter image description here

答案 1 :(得分:1)

不幸的是,使用"equal"时,图表的方面无法设置为sharex=True。可能有两种解决方案:

不共享轴

共享轴并不是必需的,因为所有子图都具有相同的尺寸。因此,我们的想法是不要共享任何轴,而只需删除上图的勾选标签。

import matplotlib.pyplot as plt
import numpy as np

values = [np.random.rand(3,10) for i in range(3)]

fig, axes = plt.subplots(len(values))
for i, (ax, value) in enumerate(zip(axes, values)):
    im = ax.imshow(value, vmax=1.0, vmin=0.0)
    ax.set_title("title")
    ax.set_xticks(range(value.shape[1]))
    if i != len(axes)-1:
        ax.set_xticklabels([])
    else:
        plt.setp(ax.get_xticklabels(), rotation=90)

plt.colorbar(im, ax=axes.ravel().tolist())


plt.show()

enter image description here

使用ImageGrid

使用ImageGrid模块中的mpl_toolkits.axes_grid1专门为相同方面的图提供网格。它可以如下使用。这里的一个主要优点是颜色条将自动与子图的大小相同。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

values = [np.random.rand(3,10) for i in range(3)]

axes = ImageGrid(plt.figure(), 111,
                 nrows_ncols=(3,1),
                 axes_pad=0.3,
                 share_all=True,
                 cbar_location="right",
                 cbar_mode="single",
                 cbar_size="2%",
                 cbar_pad=0.15,
                 label_mode = "L"
                 )

for i, (ax, value) in enumerate(zip(axes, values)):
    im = ax.imshow(value, vmax=1.0, vmin=0.0)
    ax.set_title("title")
    ax.set_xticks(range(value.shape[1]))

plt.setp(ax.get_xticklabels(), rotation=90)

ax.cax.colorbar(im)

plt.show()

enter image description here