Matplotlib 2 y轴的平均值

时间:2018-09-11 09:51:21

标签: python pandas numpy matplotlib

嗨,我一直在试图弄清楚如何将数据框绘制到图中。我的数据框看起来像这样。

Country | exports 2015 | exports 2016 | Gdp 2015 | GDP 2016| 
  A     |     500      |     600      |   34324  |  23525  | 
  B     |     435      |     335      |    3243  |   2324  |
  C     |     222      |     324      |    23423 |   1233  | 
  D     |     7756     |     9000     |    32424 |  65545  | 

基本上,我想比较所有列的均值,并将它们绘制在一个图表上,其中1 x轴为年,2 y轴为出口和GDP。我只能做一年。基本上我试图获得

         |                                    |
         |                                    |
         |                                    |
 Mean    |                                    |
Exports  |                                    | Mean GDP
         |                                    |
         |                                    |
         |____________________________________|  
              2015                  2016

在绘制图形时我是否需要以某种方式将数据转换为均值?任何建议将不胜感激:)

1 个答案:

答案 0 :(得分:1)

这是使用pandas的一种可能的解决方案。唯一的困难是设置图例位置,因为必须为每个y轴设置标签。请记住,双轴图非常令人困惑。

import pandas as pd
import matplotlib.pyplot as plt


# Stacked input data 
df = pd.DataFrame({'Country': ['A','B', 'C', 'D','A','B', 'C', 'D'],
                   'Year': ['2015','2015','2015','2015','2016','2016','2016','2016'],
                   'Export': [500, 435, 222, 7756,600, 335, 324, 9000],
                   'GDP': [34324, 3243, 23423, 32424,23525, 2324, 1233, 65545]})

# Calculate yearly means
year_means = df.groupby('Year').mean().reset_index()

# Plot the means
ax = year_means.plot(x='Year',
                     y=['Export', 'GDP'],
                     secondary_y= 'GDP',
                     kind= 'bar',
                     mark_right=False)

#Set labels
ax.set_ylabel('Exports')
ax.right_ax.set_ylabel('GDP')

# Adjust legend position
ax.legend(bbox_to_anchor=(1,1), loc="upper left")
ax.right_ax.legend(bbox_to_anchor=(1.2,1), loc="upper left")

plt.show()

enter image description here

编辑:OP没有堆叠的输入数据。解决该问题的一种方法是分别转换变量,然后将其组合为单个框架。下面的解决方案远非最佳方案。

# Not stacked input data 
df = pd.DataFrame({'Country': ['A','B', 'C', 'D'],
                   'Export 2015': [500, 435, 222, 7756],
                   'Export 2016': [600, 335, 324, 9000],
                   'GDP 2015': [34324, 3243, 23423, 32424],
                   'GDP 2016': [23525, 2324, 1233, 65545]})


def stack_variable(df, variable):

    # Get columns of the input dataframe
    names = df.columns

    # Get column names with variable of interest
    var_columns = [name for name in names if variable in name]

    # Extract years
    years = [y.split(variable + ' ')[1] for y in var_columns]

    # Empty dataframe to store results
    stacked_df = pd.DataFrame(columns = [variable, 'Year'])

    # Fill the empty frame
    for idx, col in enumerate(var_columns):
             current = pd.DataFrame({variable: df[col],
                                     'Year': years[idx]})

             stacked_df = stacked_df.append(current)



    return stacked_df


exports = stack_variable(df, 'Export')
gdp = stack_variable(df, 'GDP')

stacked_df = pd.concat([exports, gdp['GDP']], axis=1).reset_index(drop=True)

哪个会返回:

stacked_df

       Export   Year    GDP
    0   500     2015    34324
    1   435     2015    3243
    2   222     2015    23423
    3   7756    2015    32424
    4   600     2016    23525
    5   335     2016    2324
    6   324     2016    1233
    7   9000    2016    65545