How to speed up the shift of a Pandas multilevel dataframe?

Date: 2013-07-01 09:29:12

Tags: performance pandas

I am trying to shift the Pandas dataframe's column data by groups of the first index. Here is the demo code:

In [8]: df = mul_df(5,4,3)

In [9]: df
Out[9]:
                 COL000  COL001  COL002
STK_ID RPT_Date
A0000  B000     -0.5505  0.7445 -0.3645
       B001      0.9129 -1.0473 -0.5478
       B002      0.8016  0.0292  0.9002
       B003      2.0744 -0.2942 -0.7117
A0001  B000      0.7064  0.9636  0.2805
       B001      0.4763  0.2741 -1.2437
       B002      1.1563  0.0525 -0.7603
       B003     -0.4334  0.2510 -0.0105
A0002  B000     -0.6443  0.1723  0.2657
       B001      1.0719  0.0538 -0.0641
       B002      0.6787 -0.3386  0.6757
       B003     -0.3940 -1.2927  0.3892
A0003  B000     -0.5862 -0.6320  0.6196
       B001     -0.1129 -0.9774  0.7112
       B002      0.6303 -1.2849 -0.4777
       B003      0.5046 -0.4717 -0.2133
A0004  B000      1.6420 -0.9441  1.7167
       B001      0.1487  0.1239  0.6848
       B002      0.6139 -1.9085 -1.9508
       B003      0.3408 -1.3891  0.6739

In [10]: grp = df.groupby(level=df.index.names[0])

In [11]: grp.shift(1)
Out[11]:
                 COL000  COL001  COL002
STK_ID RPT_Date
A0000  B000         NaN     NaN     NaN
       B001     -0.5505  0.7445 -0.3645
       B002      0.9129 -1.0473 -0.5478
       B003      0.8016  0.0292  0.9002
A0001  B000         NaN     NaN     NaN
       B001      0.7064  0.9636  0.2805
       B002      0.4763  0.2741 -1.2437
       B003      1.1563  0.0525 -0.7603
A0002  B000         NaN     NaN     NaN
       B001     -0.6443  0.1723  0.2657
       B002      1.0719  0.0538 -0.0641
       B003      0.6787 -0.3386  0.6757
A0003  B000         NaN     NaN     NaN
       B001     -0.5862 -0.6320  0.6196
       B002     -0.1129 -0.9774  0.7112
       B003      0.6303 -1.2849 -0.4777
A0004  B000         NaN     NaN     NaN
       B001      1.6420 -0.9441  1.7167
       B002      0.1487  0.1239  0.6848
       B003      0.6139 -1.9085 -1.9508

The mul_df() code is attached here: How to speed up Pandas multilevel dataframe sum?
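The real mul_df() lives in the linked question; for readers who want to run the snippets here, a minimal stand-in that produces a frame of the same shape and index layout might look like this (the helper body is an assumption, only the shape matches the output above):

```python
import numpy as np
import pandas as pd

def mul_df(level1_count, level2_count, col_count):
    # Hypothetical stand-in for the linked mul_df(): a two-level
    # (STK_ID, RPT_Date) frame of random values, matching the layout above
    index = pd.MultiIndex.from_product(
        [['A%04d' % i for i in range(level1_count)],
         ['B%03d' % j for j in range(level2_count)]],
        names=['STK_ID', 'RPT_Date'])
    columns = ['COL%03d' % k for k in range(col_count)]
    return pd.DataFrame(np.random.randn(len(index), col_count),
                        index=index, columns=columns)

df = mul_df(5, 4, 3)
```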

Now I want to run grp.shift(1) on a big dataframe:

In [1]: df = mul_df(5000,30,400)
In [2]: grp = df.groupby(level=df.index.names[0])
In [3]: timeit grp.shift(1)
1 loops, best of 3: 5.23 s per loop

5.23s is too slow. How can it be sped up?

(My PC: Pentium Dual-Core T4200 @ 2.00GHz, 3.00GB RAM, Windows XP, Python 2.7.4, NumPy 1.7.1, Pandas 0.11.0, numexpr 2.0.1, Anaconda 1.5.0 (32-bit))

4 Answers:

Answer 0 (score: 5):

How about shifting the whole DataFrame object and then setting the first row of each group to NaN?

dfs = df.shift(1)
dfs.iloc[df.groupby(level=0).size().cumsum()[:-1]] = np.nan
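Since the answer's snippet depends on the external mul_df(), here is a self-contained sketch of the same trick on a tiny two-level frame (names are illustrative), checking that it reproduces the per-group shift exactly:

```python
import numpy as np
import pandas as pd

# Toy two-level frame standing in for mul_df()
idx = pd.MultiIndex.from_product([['A0', 'A1'], ['B0', 'B1', 'B2']],
                                 names=['STK_ID', 'RPT_Date'])
df = pd.DataFrame(np.arange(12.0).reshape(6, 2),
                  index=idx, columns=['COL000', 'COL001'])

# Shift the whole frame once, then blank the first row of each group,
# which the global shift filled with the previous group's last row
dfs = df.shift(1)
first_rows = df.groupby(level=0).size().cumsum()[:-1]
dfs.iloc[first_rows] = np.nan

# Same values as the per-group shift
assert dfs.equals(df.groupby(level=0).shift(1))
```

Because the rows of each group are contiguous in the index, a single whole-frame shift is correct everywhere except those group-boundary rows, which is why only they need to be overwritten.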

Answer 1 (score: 4):

The problem is that the shift operation is not cython-optimized, so it involves a callback to Python. Compare it with:

In [84]: %timeit grp.shift(1)
1 loops, best of 3: 1.77 s per loop

In [85]: %timeit grp.sum()
1 loops, best of 3: 202 ms per loop

An issue has been opened for this: https://github.com/pydata/pandas/issues/4095
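Conceptually, the non-cythonized path amounts to one Python-level call per group, roughly like the apply-based sketch below (toy data, names are illustrative); with 5000 groups, that per-group overhead dominates:

```python
import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_product([['A0', 'A1'], ['B0', 'B1']],
                                 names=['STK_ID', 'RPT_Date'])
df = pd.DataFrame(np.arange(8.0).reshape(4, 2),
                  index=idx, columns=['COL000', 'COL001'])

# What the fallback does conceptually: a Python call per group
slow = df.groupby(level=0, group_keys=False).apply(lambda g: g.shift(1))

assert slow.equals(df.groupby(level=0).shift(1))
```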

Answer 2 (score: 0):

A similar question, with an added answer that works for shifts in either direction and of any magnitude: pandas: setting last N rows of multi-index to Nan for speeding up groupby with shift

The code (including the test setup) is:

#
# the function to use in apply
#
def replace_shift_overlap(grp, col, N, value):
    # Blank the rows of each group that the whole-frame shift filled
    # with the neighbouring group's values
    if N > 0:
        grp.iloc[:N, grp.columns.get_loc(col)] = value
    else:
        grp.iloc[N:, grp.columns.get_loc(col)] = value
    return grp


length = 5
groups = 3
rng1 = pd.date_range('1/1/1990', periods=length, freq='D')
frames = []
for x in range(groups):  # xrange in the original Python 2 code
    tmpdf = pd.DataFrame({'date': rng1,
                          'category': int(10000000 * abs(np.random.randn())),
                          'colA': np.random.randn(length),
                          'colB': np.random.randn(length)})
    frames.append(tmpdf)
df = pd.concat(frames)

df.sort_values(['category', 'date'], inplace=True)  # df.sort in older pandas
df.set_index(['category', 'date'], inplace=True, drop=True)
shiftBy = -1
df['tmpShift'] = df['colB'].shift(shiftBy)

#
# the apply
#
df = df.groupby(level=0).apply(replace_shift_overlap, 'tmpShift', shiftBy, np.nan)

# Yay this is so much faster.
df['newColumn'] = df['tmpShift'] / df['colA']
df.drop('tmpShift', axis=1, inplace=True)

EDIT: Note that the initial sort really eats into the effectiveness of this, so in some cases the original answer is more efficient.

Answer 3 (score: 0):

Try this:

import numpy as np
import pandas as pd
df = pd.DataFrame({'A': [10, 20, 15, 30, 45,43,67,22,12,14,54],
                   'B': [13, 23, 18, 33, 48, 1,7, 56,66,45,32],
                   'C': [17, 27, 22, 37, 52,77,34,21,22,90,8],
                   'D':    ['a','a','a','a','b','b','b','c','c','c','c']
                   })
df
#>      A   B   C  D
#> 0   10  13  17  a
#> 1   20  23  27  a
#> 2   15  18  22  a
#> 3   30  33  37  a
#> 4   45  48  52  b
#> 5   43   1  77  b
#> 6   67   7  34  b
#> 7   22  56  21  c
#> 8   12  66  22  c
#> 9   14  45  90  c
#> 10  54  32   8  c
def groupby_shift(df, col, groupcol, shift_n, fill_na=np.nan):
    '''df: dataframe
       col: column to be shifted
       groupcol: grouping variable
       shift_n: how far to shift
       fill_na: value to put in the overlap rows, default np.nan
    '''
    rowno = list(df.groupby(groupcol).size().cumsum())   # last row position of each group
    lagged_col = df[col].shift(shift_n)
    na_rows = [i for i in range(shift_n)]                # rows of the first group to blank
    for i in rowno:
        if i == rowno[len(rowno)-1]:                     # nothing spills past the last group
            continue
        else:
            new = [i + j for j in range(shift_n)]        # rows contaminated by the previous group
            na_rows.extend(new)
    na_rows = list(set(na_rows))                         # deduplicate
    na_rows = [i for i in na_rows if i <= len(lagged_col) - 1]
    lagged_col.iloc[na_rows] = fill_na
    return lagged_col
    
df['A_lag_1'] = groupby_shift(df, 'A', 'D', 1)
df
#>      A   B   C  D  A_lag_1
#> 0   10  13  17  a      NaN
#> 1   20  23  27  a     10.0
#> 2   15  18  22  a     20.0
#> 3   30  33  37  a     15.0
#> 4   45  48  52  b      NaN
#> 5   43   1  77  b     45.0
#> 6   67   7  34  b     43.0
#> 7   22  56  21  c      NaN
#> 8   12  66  22  c     22.0
#> 9   14  45  90  c     12.0
#> 10  54  32   8  c     14.0
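For comparison, current pandas versions expose a cythonized per-group shift directly on the groupby object; a quick sanity check on the same toy data that it produces the same lag as groupby_shift above:

```python
import pandas as pd

df = pd.DataFrame({'A': [10, 20, 15, 30, 45, 43, 67, 22, 12, 14, 54],
                   'D': ['a','a','a','a','b','b','b','c','c','c','c']})

# Built-in per-group lag: no Python-level loop over groups
builtin = df.groupby('D')['A'].shift(1)

assert builtin.isna().sum() == 3   # one NaN at the head of each group
assert builtin.iloc[1] == 10.0     # lag stays within group 'a'
```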