Pandas:基于另一列的过滤器聚合

时间:2017-02-02 21:26:39

标签: python pandas aggregate

我有一个看起来像这样的数据框

Month   Fruit   Sales
1       Apple   45
1       Bananas 12
3       Apple   6
1       Kiwi    34
12      Melon   12

我正在尝试获取类似这样的数据框

Fruit         Sales (month=1)     Sales (month=2)
Apple         55                  65
Bananas       12                  102
Kiwi          54                  78
Melon         132                 43

现在我有

df=df.groupby(['Fruit']).agg({'Sales':np.sum}).reset_index()

必须有一些方法可以根据“Month”变量过滤agg()中的参数。我只是无法在文档中找到它。有什么帮助吗?

编辑: 谢谢你的解决方案。为了使事情复杂化,我想总结另一个专栏。例如:

Month    Fruit    Sales  Revenue
1       Apple    45     45
1       Bananas  12     12
3       Apple    6      6
1       Kiwi     34     34
12      Melon    12     12

首选输出类似于

            Sales      Revenue
     Fruit   1  3  12  1   3  12
0    Apple  61  6   0  61  6  0
1  Bananas  12  6   0  12  6  0
2     Kiwi  34  0   0  34  0  0
3    Melon   0  0  12  0   0  12

我设法通过df.pivot_table(values=['Sales','Revenue'], index='Fruit', columns=['Month'], aggfunc='np.sum').reset_index()获得此功能,因此我的问题得到了解决。

我尝试使用df.groupby(['Fruit', 'Month'])['Sales','Revenue'].sum().unstack('Month', fill_value=0).rename_axis(None, 1).reset_index()进行同样的操作,但这会引发TypeError。以上操作也可以使用groupby完成吗?

2 个答案:

答案 0 :(得分:4)

要回答更新的问题,您应该做一些不同的事情。 首先group by后面应该是列的元素(Month和Fruit)。然后计算这些组的总和以及unstack之后的DataFrame,它将Fruit列作为索引列。

data = '''
Month    Fruit   Sales  Revenue
1       Apple    45     45
1       Bananas  12     12
1       Apple    16     16
3       Apple    6      6
1       Kiwi     34     34
3       Bananas  6      6
12      Melon    12     12
'''
df = pd.read_csv(StringIO(data), sep='\s+')

df.groupby(['Month', 'Fruit'])\
    .sum()\
    .unstack(level=0)

结果

        Sales            Revenue           
Month      1    3     12      1    3     12
Fruit                                      
Apple    61.0  6.0   NaN    61.0  6.0   NaN
Bananas  12.0  6.0   NaN    12.0  6.0   NaN
Kiwi     34.0  NaN   NaN    34.0  NaN   NaN
Melon     NaN  NaN  12.0     NaN  NaN  12.0

旧答案

使用pivot_table方法:

import pandas as pd
from io import StringIO

data = '''\
Month Fruit  Sales
1       Apple   45
1       Bananas 12
1       Apple   16
3       Apple   6
1       Kiwi    34
3       Bananas 6
12      Melon   12
'''
df = pd.read_csv(StringIO(data), sep='\s+')

df.pivot_table('Sales', index='Fruit', columns=['Month'], aggfunc='sum')

结果:

Month      1    3     12
Fruit                   
Apple    61.0  6.0   NaN
Bananas  12.0  6.0   NaN
Kiwi     34.0  NaN   NaN
Melon     NaN  NaN  12.0

答案 1 :(得分:0)

<强>更新

In [177]: df
Out[177]:
   Month    Fruit  Sales  Revenue
0      1    Apple     45       45
1      1  Bananas     12       12
2      3    Apple      6        6
3      1     Kiwi     34       34
4     12    Melon     12       12

In [178]: df.groupby(['Fruit', 'Month'])[['Sales','Revenue']].sum().unstack('Month', fill_value=0)
Out[178]:
        Sales        Revenue
Month      1  3   12      1  3   12
Fruit
Apple      45  6   0      45  6   0
Bananas    12  0   0      12  0   0
Kiwi       34  0   0      34  0   0
Melon       0  0  12       0  0  12

OLD回答:

或者,您可以使用groupby() + unstack()

In [206]: df.groupby(['Fruit', 'Month'])['Sales'].sum().unstack('Month', fill_value=0) \
     ...:   .rename_axis(None, 1).reset_index()
     ...:
Out[206]:
     Fruit   1  3  12
0    Apple  61  6   0
1  Bananas  12  6   0
2     Kiwi  34  0   0
3    Melon   0  0  12