使用for循环对plt.subplots()返回的Axes对象进行插值

时间:2018-12-01 04:20:38

标签: matplotlib seaborn

我正在尝试使用for循环通过以下代码填充Axes中的每个subplots

df = sns.load_dataset('iris')
cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
# plotting
fig, ax = plt.subplots(2,2)
for ax_row in range(2):
    for ax_col in range(2):
        for col in cols:
            sns.distplot(df[col], ax=ax[ax_row][ax_col])

但是我在所有四个轴上都有相同的图。我应该如何更改才能使其正常工作?

1 个答案:

答案 0 :(得分:1)

问题是for col in cols:,您在其中遍历每个子图的所有列。相反,您需要在一个子图中一次绘制一列。为此,一种方法是使用索引i并在循环遍历子图时不断对其进行更新。下面是答案:

import seaborn as sns
df = sns.load_dataset('iris')
cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
# plotting
fig, ax = plt.subplots(2,2, figsize=(8, 6))
i = 0
for ax_row in range(2):
    for ax_col in range(2):
        ax_ = sns.distplot(df[cols[i]], ax=ax[ax_row][ax_col])
        i += 1
plt.tight_layout()  

编辑:使用enumerate

fig, ax = plt.subplots(2,2, figsize=(8, 6))
for i, axis in enumerate(ax.flatten()):        
    ax_ = sns.distplot(df[cols[i]], ax=axis)
plt.tight_layout()  

编辑2:在enumerate上使用cols

fig, axes = plt.subplots(2,2, figsize=(8, 6))
for i, col in enumerate(cols):        
    ax_ = sns.distplot(df[col], ax=axes.flatten()[i])
plt.tight_layout()  

enter image description here