防止截止不同维度的imshow子图w / sharex

时间:2016-05-09 21:25:29

标签: python matplotlib subplot imshow

我想要在同一个图的子图中用imshow显示几个矩阵。它们都具有相同数量的列但行数不同。我想:

  1. 使用imshow
  2. 显示所有每个矩阵
  3. 保留aspect=1
  4. imshow效果
  5. 在图中使用sharex
  6. (这一共意味着子图的高度反映了矩阵中不同数量的行)。我尝试使用gridspec(通过gridspec_kw的{​​{1}}参数),但plt.subplotssharex的组合导致部分矩阵被切断,除非我手动调整窗口大小。例如:

    aspect=1

    resulting image with default figure dims

    根据每个矩阵中的行数,我可以猜测使它们全部可见,但它不起作用的球场图形尺寸:

    import numpy as np
    import matplotlib.pyplot as plt
    # fake data
    foo = np.arange(5 * 7).reshape(5, 7)
    bar = np.arange(11 * 7).reshape(11, 7)
    baz = np.arange(3 * 7).reshape(3, 7)
    
    data = [foo, bar, baz]
    nrows = [x.shape[0] for x in data]
    row_labels = np.array([x for x in 'abcdefghijk'])
    col_labels = [x for x in 'ABCDEFG']
    
    # initialize figure
    fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                            gridspec_kw=dict(height_ratios=nrows))
    
    for ix, d in enumerate(data):
        ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
        _ = ax.imshow(d)
        _ = ax.yaxis.set_ticks(range(d.shape[0]))
        _ = ax.xaxis.set_ticks(range(d.shape[1]))
        _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
        _ = ax.xaxis.set_ticklabels(col_labels)
    

    注意所有3个子图的顶部和底部行是如何被部分切除的(它在中间部分最容易看到)但是在顶部和底部的图形边缘有大量的多余空白:

    resulting image with custom figure dims

    使用figsize = (foo.shape[1], sum(nrows)) fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True, gridspec_kw=dict(height_ratios=nrows), figsize=figsize) for ix, d in enumerate(data): ax = axs[ix % axs.shape[0], ix // axs.shape[0]] _ = ax.imshow(d) _ = ax.yaxis.set_ticks(range(d.shape[0])) _ = ax.xaxis.set_ticks(range(d.shape[1])) _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])]) _ = ax.xaxis.set_ticklabels(col_labels) 也无法解决问题;它使子图变得太大(注意轴脊与图像之间每个子图的顶部/底部的间隙):

    resulting image with custom dims and tight layout

    有没有办法让tight_layoutimshow在这里和谐地工作?

1 个答案:

答案 0 :(得分:2)

我刚刚发现了ImageGrid,这很有效。完整的例子:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']
fig = plt.figure()
axs = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.1)
for ix, d in enumerate(data):
    ax = axs[ix]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)

result using ImageGrid