我想要在同一个图的子图中用imshow
显示几个矩阵。它们都具有相同数量的列但行数不同。我想:
imshow
aspect=1
imshow
效果
sharex
(这一共意味着子图的高度反映了矩阵中不同数量的行)。我尝试使用gridspec
(通过gridspec_kw
的{{1}}参数),但plt.subplots
和sharex
的组合导致部分矩阵被切断,除非我手动调整窗口大小。例如:
aspect=1
根据每个矩阵中的行数,我可以猜测应使它们全部可见,但它不起作用的球场图形尺寸:
import numpy as np
import matplotlib.pyplot as plt
# fake data
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']
# initialize figure
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
gridspec_kw=dict(height_ratios=nrows))
for ix, d in enumerate(data):
ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
_ = ax.imshow(d)
_ = ax.yaxis.set_ticks(range(d.shape[0]))
_ = ax.xaxis.set_ticks(range(d.shape[1]))
_ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
_ = ax.xaxis.set_ticklabels(col_labels)
注意所有3个子图的顶部和底部行是如何被部分切除的(它在中间部分最容易看到)但是在顶部和底部的图形边缘有大量的多余空白:
使用figsize = (foo.shape[1], sum(nrows))
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
gridspec_kw=dict(height_ratios=nrows),
figsize=figsize)
for ix, d in enumerate(data):
ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
_ = ax.imshow(d)
_ = ax.yaxis.set_ticks(range(d.shape[0]))
_ = ax.xaxis.set_ticks(range(d.shape[1]))
_ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
_ = ax.xaxis.set_ticklabels(col_labels)
也无法解决问题;它使子图变得太大(注意轴脊与图像之间每个子图的顶部/底部的间隙):
有没有办法让tight_layout
和imshow
在这里和谐地工作?
答案 0 :(得分:2)
我刚刚发现了ImageGrid
,这很有效。完整的例子:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']
fig = plt.figure()
axs = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.1)
for ix, d in enumerate(data):
ax = axs[ix]
_ = ax.imshow(d)
_ = ax.yaxis.set_ticks(range(d.shape[0]))
_ = ax.xaxis.set_ticks(range(d.shape[1]))
_ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
_ = ax.xaxis.set_ticklabels(col_labels)