更快的替代分组/班次

时间:2014-08-14 01:22:45

标签: python pandas

如果您有大量的群组,则“groupby”的一般标题下的文件很慢

nobs =   9999 

df = DataFrame( { 'id' : np.arange(nobs) / 3,
                  'yr' : np.tile( np.array([2007,2008,2009]), nobs/3 ),
                  'val': np.random.randn(nobs) } )

df = df.sort(['id','yr'])

A = df.groupby('id').shift()
B = df.shift()

A是我想要的,但这里需要大约1.5秒,而我的实际用例大约有100倍的观察结果。作为参考,计算A比计算B慢约1000倍。

这是A和B的样子:

In [599]: A.head(6)
Out[599]: 
        val    yr
0       NaN   NaN
1 -0.839041  2007
2 -1.089094  2008
3       NaN   NaN
4 -0.068383  2007
5  0.555293  2008

In [600]: B.head(6)
Out[600]: 
   id       val    yr
0 NaN       NaN   NaN
1   0 -0.839041  2007
2   0 -1.089094  2008
3   0  0.050604  2009
4   1 -0.068383  2007
5   1  0.555293  2008

我喜欢加速A的一般解决方案,但缺席的是,解决方法会很棒。如您所见,B实际上与A相同,除了每个组的第一个值不是真正有效并且需要转换为NaN。它可以用groupby / rank来完成,但是涉及groupby的任何东西似乎都很慢,所以我需要一个非groupby方法。

有没有办法通过排序或索引来复制排名功能?似乎信息必须嵌入那里,但我不知道如何将它提取到一个新的变量。

(编辑添加以下内容)

以下是Jeff提供的链接解决方案(HYRY的原始答案)。我刚刚稍微修改它以使用这里的示例。在我的计算机上,它的运行速度几乎与DSM的解决方案完全相同。

B.iloc[df.groupby('id').size().cumsum()[:-1]] = np.nan

2 个答案:

答案 0 :(得分:4)

这不是最优雅的代码片段,但作为针对您的案例的黑客解决方法,如下所示:

def fast(df, col):
    A = df.sort(col).shift()
    last = A.iloc[-1].copy()
    A.loc[A[col] != A[col].shift(-1)] = np.nan
    A.iloc[-1] = last
    return A.drop(col, axis=1)

def slow(df, col):
    A = df.sort(col).groupby(col).shift()
    return A

给出了

>>> %timeit s = slow(df, "id")
1 loops, best of 3: 2.09 s per loop
>>> %timeit f = fast(df, "id")
100 loops, best of 3: 3.51 ms per loop
>>> slow(df, "id").equals(fast(df, "id"))
True

答案 1 :(得分:0)

尝试一下:

import numpy as np
import pandas as pd
df = pd.DataFrame({'A': [10, 20, 15, 30, 45,43,67,22,12,14,54],
                   'B': [13, 23, 18, 33, 48, 1,7, 56,66,45,32],
                   'C': [17, 27, 22, 37, 52,77,34,21,22,90,8],
                   'D':    ['a','a','a','a','b','b','b','c','c','c','c']
                   })
df
#>      A   B   C  D
#> 0   10  13  17  a
#> 1   20  23  27  a
#> 2   15  18  22  a
#> 3   30  33  37  a
#> 4   45  48  52  b
#> 5   43   1  77  b
#> 6   67   7  34  b
#> 7   22  56  21  c
#> 8   12  66  22  c
#> 9   14  45  90  c
#> 10  54  32   8  c
def groupby_shift(df, col, groupcol, shift_n, fill_na = np.nan):
    '''df: dataframe
       col: column need to be shifted 
       groupcol: group variable
       shift_n: how much need to shift
       fill_na: how to fill nan value, default is np.nan 
    '''
    rowno = list(df.groupby(groupcol).size().cumsum()) 
    lagged_col = df[col].shift(shift_n)
    na_rows = [i for i in range(shift_n)] 
    for i in rowno:
        if i == rowno[len(rowno)-1]: 
            continue 
        else:
            new = [i + j for j in range(shift_n)]
            na_rows.extend(new) 
    na_rows = list(set(na_rows)) 
    na_rows = [i for i in na_rows if i <= len(lagged_col) - 1] 
    lagged_col.iloc[na_rows] = fill_na
    return lagged_col

结果给出:

    
df['A_lag_1'] = groupby_shift(df, 'A', 'D', 1)
df
#>      A   B   C  D  A_lag_1
#> 0   10  13  17  a      NaN
#> 1   20  23  27  a     10.0
#> 2   15  18  22  a     20.0
#> 3   30  33  37  a     15.0
#> 4   45  48  52  b      NaN
#> 5   43   1  77  b     45.0
#> 6   67   7  34  b     43.0
#> 7   22  56  21  c      NaN
#> 8   12  66  22  c     22.0
#> 9   14  45  90  c     12.0
#> 10  54  32   8  c     14.0