使用for循环添加到子图中

时间:2017-06-19 18:54:52

标签: python pandas matplotlib

我的问题非常简单。我有这个函数,为数据框中的每一列创建一个特定的图形。然而,输出是7个单独的图。我可以制作一个像这样的4x2子图:

f, axarr = plt.subplots(4, 2, figsize = (10, 10))

获取this empty chart

这是我的情节的代码。如何/我应该填写子图而不是返回7个单独的图?包括数据帧的头部以供参考

for index in weights.columns:    
    fig = plt.figure(figsize = (10, 6))
    ax = fig.add_subplot(1, 1, 1)
    ##this gets the bottom axis to be in the middle so you can see 
    ##clearly positive or negative returns
    ax.spines['left'].set_position(('data', 0.0))
    ax.spines['bottom'].set_position(('data', 0.0))

    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')

    ax.set_ylabel('{} One month forward return'.format(index))
    ax.set_xlabel('Percent of Max Exposure')

    ##get the ticks in percentage format

    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: '{:.0%}'.format(x)))

    plt.title('One Month Forward {} Returns Versus Total Exposure'.format(index))
    plt.scatter(weights_scaled[index], forward_returns[index], marker = 'o')

weights_scaled.head()

缺货[415]:             美国股票发展前美国BMI新兴BMI美国房地产
日期
1999-12-31 0.926819 0.882021 0.298016 0.0
2000-01-31 0.463410 0.882021 0.298016 1.0
2000-02-29 0.463410 0.882021 0.298016 0.5
2000-03-31 0.926819 0.882021 0.298016 1.0
2000-04-28 0.926819 0.441010 0.000000 1.0

        Commodity  Gold  US Bonds  

日期
1999-12-31 1.0 1.0 0.051282
2000-01-31 1.0 1.0 0.232785
2000-02-29 1.0 1.0 0.258426
2000-03-31 1.0 0.5 0.025641
2000-04-28 1.0 0.5 0.244795

1 个答案:

答案 0 :(得分:1)

这段代码导致了这个问题:

for index in weights.columns:    
    fig = plt.figure(figsize = (10, 6))
    ax = fig.add_subplot(1, 1, 1)

对于每一列,它都会在该图上创建一个新图形和一个新轴。相反,您应该使用axarr返回您的第一直觉,然后在迭代数据框中的列时,将该数组的一个轴指定给一个变量,在该变量上绘制该列中的数据。

一个虚拟示例如下所示:

# Create array of 8 subplots
f, axarr = plt.subplots(4, 2, figsize=(10,10))

# Create dummy data for my example
new_dict = {c: np.random.randint(low=1, high=10, size=40) for c in ['a','b','c','d','e','f','g']}

df = pd.DataFrame(new_dict)

# Enumerate columns, providing index and column name
for i, col in enumerate(df.columns):

    # Select subplot from list generated earlier
    ax = axarr.flat[i]

    # Select column and plot data on subplot axis
    df[col].hist(ax=ax)

Subplots

编辑代码的相关部分,我想你想要:

for i, col in enumerate(weights.columns):    

    ax = axarr.flat[i]

    ax.set_ylabel('{} One month forward return'.format(col))

    ...

    plt.title('One Month Forward {} Returns Versus Total Exposure'.format(col))
    plt.scatter(weights_scaled[col], forward_returns[col], marker = 'o')