Using np.select to create a new column on a MultiIndex DataFrame

Time: 2020-01-13 19:32:36

Tags: python pandas

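I'm trying to determine whether my data crosses a line and, if so, from which direction. I have this working with np.select on a single-index DataFrame, but when I try the same thing on a MultiIndex DataFrame, I get all NaN.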

Here is my helper function:

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))]
    return np.select(line_crossed_cond, [1, -1], default = np.nan)

I call it like this:

df['Hcross'] = df.groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highBound'))

The groupby with the helper function returns:

Symbol
AAPL    [nan, nan, -1.0, nan, nan, 1.0, nan, -1.0, nan...
AMZN    [nan, nan, nan, nan, nan, nan, nan, -1.0, nan,...

But the df['Hcross'] column is assigned all NaN:

                    Close   Hcross
Symbol Date                    
AAPL   2019-12-02   264.16  NaN
       2019-12-03   259.45  NaN
       2019-12-04   261.74  NaN
       2019-12-05   265.58  NaN
       2019-12-06   270.71  NaN
       2019-12-09   266.92  NaN
       2019-12-10   268.48  NaN
       2019-12-11   270.77  NaN
       2019-12-12   271.46  NaN
       2019-12-13   275.15  NaN
AMZN   2019-12-02  1781.60  NaN
       2019-12-03  1769.96  NaN
       2019-12-04  1760.69  NaN
       2019-12-05  1740.48  NaN
       2019-12-06  1751.60  NaN
       2019-12-09  1749.51  NaN
       2019-12-10  1739.21  NaN
       2019-12-11  1748.72  NaN
       2019-12-12  1760.33  NaN
       2019-12-13  1760.94  NaN

I think I need to somehow flatten the arrays returned from the helper function, but I don't know how.

1 Answer:

Answer 0 (score: 1)

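A simple fix is to return a Series indexed like the DataFrame. Because np.select returns an array of the same length as the DataFrame passed in, the values line up correctly with that index.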

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))] 

    return pd.Series(np.select(line_crossed_cond, [1, -1], default = np.nan), index=df.index)

Now the groupby returns a Series with a matching MultiIndex:

df.assign(highbound=265).groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highbound'))

Symbol  Date      
AAPL    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05   -1.0
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
AMZN    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05    NaN
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
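
As a self-contained illustration of the fix above, here is a minimal sketch on made-up data. The symbols "AAA" and "BBB", the prices, and the constant highBound of 10 are all assumptions chosen for the example and are not taken from the question's data. The helper returns a Series that carries each group's index, so the combined result can be assigned straight back to the MultiIndex DataFrame.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: two symbols, four dates each, constant bound of 10
idx = pd.MultiIndex.from_product(
    [["AAA", "BBB"], pd.date_range("2019-12-02", periods=4)],
    names=["Symbol", "Date"],
)
df = pd.DataFrame(
    {"Close": [9.0, 11.0, 12.0, 8.0, 12.0, 9.0, 8.0, 11.0], "highBound": 10.0},
    index=idx,
)

def calc_crossings_helper(group, line):
    # 1 when Close drops below the line, -1 when it rises above it, else NaN
    conds = [
        (group["Close"] < group[line]) & (group["Close"].shift(1) > group[line].shift(1)),
        (group["Close"] > group[line]) & (group["Close"].shift(1) < group[line].shift(1)),
    ]
    # Wrapping the array in a Series keeps the group's (Symbol, Date) index
    return pd.Series(np.select(conds, [1, -1], default=np.nan), index=group.index)

result = df.groupby(level=0, group_keys=False).apply(calc_crossings_helper, "highBound")
df["Hcross"] = result  # aligns on the full MultiIndex
print(df)
```

The first row of each symbol stays NaN because shift(1) has no previous value inside the group.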

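Even better, given that your DataFrame is sorted, there is no need for groupby.apply() at all. You can add a grouping condition by using shift on the Symbol level, so only a single np.select call is needed.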

line = 'highbound'
# Series b/c there is no pd.Index.shift method
s = pd.Series(df.index.get_level_values('Symbol'), index=df.index)

line_crossed_cond = [(s.eq(s.shift()) 
                      & (df['Close'] < df[line]) 
                      & (df['Close'].shift(1) > df[line].shift(1))),
                     (s.eq(s.shift())
                      & (df['Close'] > df[line]) 
                      & (df['Close'].shift(1) < df[line].shift(1)))]

df['Hcross'] = np.select(line_crossed_cond, [1, -1], default = np.nan)
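
To show what the same-symbol condition protects against, here is a small sketch on made-up data (the symbols, prices, and the constant highbound of 10 are assumptions for illustration only). The last AAA close is below the line and the first BBB close is above it, so a plain shift across the whole frame would report a crossing on the first BBB row. The s.eq(s.shift()) guard suppresses it.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data, sorted by Symbol then Date
idx = pd.MultiIndex.from_product(
    [["AAA", "BBB"], pd.date_range("2019-12-02", periods=4)],
    names=["Symbol", "Date"],
)
df = pd.DataFrame(
    {"Close": [9.0, 11.0, 12.0, 8.0, 12.0, 9.0, 8.0, 11.0], "highbound": 10.0},
    index=idx,
)
line = "highbound"

# True only where the previous row belongs to the same symbol
sym = pd.Series(df.index.get_level_values("Symbol"), index=df.index)
same_symbol = sym.eq(sym.shift())

down = (df["Close"] < df[line]) & (df["Close"].shift(1) > df[line].shift(1))
up = (df["Close"] > df[line]) & (df["Close"].shift(1) < df[line].shift(1))

guarded = np.select([same_symbol & down, same_symbol & up], [1, -1], default=np.nan)
unguarded = np.select([down, up], [1, -1], default=np.nan)

df["Hcross"] = guarded
print(df.assign(unguarded=unguarded))
```

This relies on the rows of each symbol being contiguous and in date order. If that is not guaranteed, sorting the index first (for example with df.sort_index()) is a reasonable precaution.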