Python pandas滚动意味着保留索引和列

时间:2017-04-02 17:33:45

标签: python pandas

我有一个关于NBA比赛的熊猫DataFrame统计数据。以下是客队的数据样本:

                                  away_team  away_efg  away_drb  away_score
date
2000-10-31 19:00:00    Los Angeles Clippers     0.522      74.4          94
2000-10-31 19:00:00         Milwaukee Bucks     0.434      63.0          93
2000-10-31 19:30:00  Minnesota Timberwolves     0.523      73.8         106
2000-10-31 19:30:00       Charlotte Hornets     0.605      77.1         106
2000-10-31 19:30:00     Seattle SuperSonics     0.429      73.1          88

away_score列之外还有更多数字列,以及主队的类似列。

我想要的是,对于每一行,将数字列(除了得分)替换为前三个观察的平均值,由团队划分。通过执行以下操作,我几乎得到我想要的东西:

home_df.groupby("team").apply(lambda x: x.rolling(window=3).mean())

这会返回,例如

>>> home_avg[home_avg["team"]=="Utah Jazz"].head()
         3par        ast   blk        drb       efg       ftr        orb
0         NaN        NaN   NaN        NaN       NaN       NaN        NaN
50        NaN        NaN   NaN        NaN       NaN       NaN        NaN
81   0.146667  71.600000   9.4  74.666667  0.512000  0.347667  25.833333

考虑到这一点,以及

>>> home_df[home_df["team"]=="Utah Jazz"].head()
      3par   ast   blk   drb    efg    ftr   orb   stl       team   tov   trb
0    0.118  76.7   7.1  64.7  0.535  0.365  25.6  11.5  Utah Jazz  10.8  42.9
50   0.100  63.9   9.1  80.5  0.536  0.414  27.6   2.2  Utah Jazz  20.2  58.6
81   0.222  74.2  12.0  78.8  0.465  0.264  24.3   7.3  Utah Jazz  13.9  50.0
122  0.119  81.8  11.3  75.0  0.515  0.642  25.0  12.2  Utah Jazz  21.8  52.5
135  0.129  76.7  17.8  75.9  0.650  0.400  37.9   5.7  Utah Jazz  18.8  62.7

表明它在计算均值时包含当前行。我想避免这种情况。更具体地说,第81行的所需输出将是全部NaN s(因为还没有三个游戏),并且第122行的3par列中的条目将是.146667 (行0,50和81中该列的平均值)。

所以,我的问题是,如何排除滚动平均值计算中的当前行?

1 个答案:

答案 0 :(得分:3)

您可以在shift处使用移动给定金额的索引,以使滚动窗口使用除当前值之外的最后三个值:

# create dummy data frame with numeric values
df = pd.DataFrame({"numeric_col": np.random.randint(0, 100, size=5)})
print(df)

    numeric_col
0   66
1   60
2   74
3   41
4   83

df["mean"] = df["numeric_col"].shift(1).rolling(window=3).mean()
print(df)

    numeric_col     mean
0   66              NaN
1   60              NaN
2   74              NaN
3   41              66.666667
4   83              58.333333

因此,请将您的应用功能更改为lambda x: x.shift(1).rolling(window=3).mean(),以使其在您的具体示例中有效。