matplotlib中的axes.flat做什么?

时间:2017-10-21 11:40:54

标签: python matplotlib

我见过使用matplotlib的各种程序,它们使用axes.flat函数,就像这段代码一样:

for i, ax in enumerate(axes.flat):

这是做什么的?

2 个答案:

答案 0 :(得分:16)

让我们看一个最小的例子,我们使用plt.subplots创建一些轴,另见this question

import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=2,nrows=3, sharex=True, sharey=True)

for i, ax in enumerate(axes.flat):
    ax.scatter([i//2+1, i],[i,i//3])

plt.show()

在这里,axes是一个numpy轴,

print(type(axes))
> <type 'numpy.ndarray'>
print(axes.shape)
> (3L, 2L)

axes.flat不是函数,它是numpy.ndarray的一个属性:numpy.ndarray.flat

  

ndarray.flat   阵列上的一维迭代器   这是一个numpy.flatiter实例,它的作用类似于Python的内置迭代器对象,但不是它的子类。

示例:

import numpy as np

a = np.array([[2,3],
              [4,5],
              [6,7]])

for i in a.flat:
    print(i)

会打印数字2 3 4 5 6 7

作为数组的迭代器,您可以使用它来遍历3x2轴上的所有轴,

for i, ax in enumerate(axes.flat):

对于每次迭代,它将从该数组中产生下一个轴,这样您可以轻松地在单个循环中绘制所有轴。

另一种方法是使用axes.flatten(),其中flatten()是numpy数组的方法。它不是迭代器,而是返回数组的扁平化版本:

for i, ax in enumerate(axes.flatten()):

两者之间没有区别。但是,迭代器实际上并不会创建一个新数组,因此可能会稍快一些(尽管在matplotlib轴对象的情况下这永远不会显着)。

flat1 = [ax for ax in axes.flat]
flat2 = axes.flatten()
print(flat1 == flat2)
> [ True  True  True  True  True  True]

迭代一个扁平版本的axes数组有一个优点,就是你可以保存一个循环,而不是分别迭代行和列的简单方法,

for row in axes:
    for ax in row:
        ax.scatter(...)

答案 1 :(得分:-1)

fig, ax = plt.subplots(3, 3, figsize=())
ax = ax.flatten()
for i, col in enumerate(columns):
    sns.distplot(d2[col], ax=ax[i])
plt.tight_layout()