我有一个子图子图。外部子图由一行乘两列组成,两个内部子图分别由四行和四列组成。假设我想要仅与第一个2x2
内部子图相对应的图例标签。我该怎么做呢?我的尝试如下:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
outerD = dict(nrows=1, ncols=2)
innerD = dict(nrows=2, ncols=2)
D = dict(inner=innerD, outer=outerD)
def initialize_dubsub(D, figsize=None):
""" """
fig = plt.figure(figsize=figsize)
outerG = gridspec.GridSpec(D['outer']['nrows'], D['outer']['ncols'], wspace=0.2, hspace=0.2, width_ratios=[5, 5])
axes = []
for n in range(D['inner']['nrows']):
inner = gridspec.GridSpecFromSubplotSpec(D['inner']['nrows'], D['inner']['ncols'], subplot_spec=outerG[n], wspace=0.25, hspace=0.3, width_ratios=[10, 10], height_ratios=[2, 2])
for m in range(D['inner']['nrows']*D['inner']['ncols']):
ax = plt.Subplot(fig, inner[m])
ax.plot([], [], label='{}x{}'.format(n, m))
ax.set_xticks([])
ax.set_yticks([])
axes.append(ax)
fig.add_subplot(ax)
# handles, labels = axes[:4].get_legend_handles_labels() # first 2x2
# fig.legend(handles=handles, labels=labels, loc='lower center')
fig.legend(loc='lower center', ncol=4, mode='expand')
plt.show()
plt.close(fig)
initialize_dubsub(D)
此代码将输出8
handles
和8
labels
,而我要每个4
。我注释了get_legend_handles_labels()
方法,因为它不适用于数组。
我知道我可以做ax.legend()
,但是我更喜欢使用fig.legend(...)
。我该如何实现?
答案 0 :(得分:1)
您可以仅循环遍历该数组中的轴,然后将这四个子图中的句柄和标签附加到列表中,而不必尝试在所需的子图中的数组上调用.get_legend_handles_labels
。
例如:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
outerD = dict(nrows=1, ncols=2)
innerD = dict(nrows=2, ncols=2)
D = dict(inner=innerD, outer=outerD)
def initialize_dubsub(D, figsize=None):
""" """
fig = plt.figure(figsize=figsize)
outerG = gridspec.GridSpec(D['outer']['nrows'], D['outer']['ncols'], wspace=0.2, hspace=0.2, width_ratios=[5, 5])
axes = []
for n in range(D['inner']['nrows']):
inner = gridspec.GridSpecFromSubplotSpec(D['inner']['nrows'], D['inner']['ncols'], subplot_spec=outerG[n], wspace=0.25, hspace=0.3, width_ratios=[10, 10], height_ratios=[2, 2])
for m in range(D['inner']['nrows']*D['inner']['ncols']):
ax = plt.Subplot(fig, inner[m])
ax.plot([], [], label='{}x{}'.format(n, m))
ax.set_xticks([])
ax.set_yticks([])
axes.append(ax)
fig.add_subplot(ax)
handles, labels = [], []
for ax in axes[:4]:
handles_, labels_ = ax.get_legend_handles_labels()
handles += handles_
labels += labels_
fig.legend(handles=handles, labels=labels, loc='lower center')
#fig.legend(loc='lower center', ncol=4, mode='expand')
plt.show()
plt.close(fig)
initialize_dubsub(D)
答案 1 :(得分:1)
尝试更换
ax.plot([], [], label='{}x{}'.format(n, m))
作者
ax.plot([], [], label=('' if n==0 else '_') + '{}x{}'.format(n, m))
如果我了解您的设置正确...