我正在尝试创建一个包含9个子图(3 x 3)的图形。 X和Y轴数据使用groupby来自数据帧。这是我的代码:
fig, axs = plt.subplots(3,3)
for index,cause in enumerate(cause_list):
df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean().axs[index].plot()
axs[index].set_title(cause)
plt.show()
但是,它无法产生所需的输出。实际上,它返回了错误。如果我删除axs[index]
之前的plot()
并将其放入plot()
之类的plot(ax=axs[index])
函数中,则它可以工作并产生9个子图,但不显示其中的数据(如该图)。
有人可以指导我在哪里犯错吗?
答案 0 :(得分:1)
您需要展平axs
,否则它是一个二维数组。并且您可以提供绘图功能中的ax,请参见documentation of pandas plot,因此使用示例:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
cause_list = np.arange(9)
df = pd.DataFrame({'CAT':np.random.choice(cause_list,100),
'RYQ':np.random.choice(['A','B','C'],100),
'NO_CONSUMERS':np.random.normal(0,1,100)})
fig, axs = plt.subplots(3,3,figsize=(8,6))
axs = axs.flatten()
for index,cause in enumerate(cause_list):
df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean().plot(ax=axs[index])
axs[index].set_title(cause)
plt.tight_layout()