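I am trying to determine whether my data crosses a line and from which direction. Using np.select works on a DataFrame with a single-level index, but when I try the same thing on a MultiIndex DataFrame, I get all NaN.
Here is my helper function: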
import numpy as np

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))]
    return np.select(line_crossed_cond, [1, -1], default = np.nan)
I call it like this:
df['Hcross'] = df.groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highBound'))
The groupby.apply call with the helper returns one array per symbol:
Symbol
AAPL [nan, nan, -1.0, nan, nan, 1.0, nan, -1.0, nan...
AMZN [nan, nan, nan, nan, nan, nan, nan, -1.0, nan,...
But the df['Hcross'] column ends up as all NaN:
                     Close  Hcross
Symbol Date
AAPL   2019-12-02   264.16     NaN
       2019-12-03   259.45     NaN
       2019-12-04   261.74     NaN
       2019-12-05   265.58     NaN
       2019-12-06   270.71     NaN
       2019-12-09   266.92     NaN
       2019-12-10   268.48     NaN
       2019-12-11   270.77     NaN
       2019-12-12   271.46     NaN
       2019-12-13   275.15     NaN
AMZN   2019-12-02  1781.60     NaN
       2019-12-03  1769.96     NaN
       2019-12-04  1760.69     NaN
       2019-12-05  1740.48     NaN
       2019-12-06  1751.60     NaN
       2019-12-09  1749.51     NaN
       2019-12-10  1739.21     NaN
       2019-12-11  1748.72     NaN
       2019-12-12  1760.33     NaN
       2019-12-13  1760.94     NaN
I think I need to somehow flatten the arrays returned by the helper function, but I don't know how.
Answer 0 (score: 1)
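A simple fix is to return a Series that is indexed like the DataFrame. np.select returns an array of the same length as the group it receives, so wrapping that array in a Series with the group's index lets pandas align the values correctly.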
import numpy as np
import pandas as pd

def calc_crossings_helper(df, line):
    # define crossing conditions - corresponding choices are [1,-1] to denote direction, otherwise NaN
    line_crossed_cond = [(df['Close'] < df[line]) & (df['Close'].shift(1) > df[line].shift(1)),
                         (df['Close'] > df[line]) & (df['Close'].shift(1) < df[line].shift(1))]
    return pd.Series(np.select(line_crossed_cond, [1, -1], default = np.nan), index=df.index)
Now the groupby returns a Series with a MultiIndex that matches the original:
df.assign(highbound=265).groupby(level=0, group_keys=False).apply(calc_crossings_helper, ('highbound'))
Symbol  Date
AAPL    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05   -1.0
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
AMZN    2019-12-02    NaN
        2019-12-03    NaN
        2019-12-04    NaN
        2019-12-05    NaN
        2019-12-06    NaN
        2019-12-09    NaN
        2019-12-10    NaN
        2019-12-11    NaN
        2019-12-12    NaN
        2019-12-13    NaN
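To make the difference concrete, here is a small self-contained sketch that contrasts the two return types. The prices are made up for illustration (they are not the data from the question), and the names crossing_conditions, as_array and as_series are hypothetical helpers introduced only for this example.

```python
import numpy as np
import pandas as pd

# Made-up prices for illustration; not the data from the question.
idx = pd.MultiIndex.from_product(
    [["AAPL", "AMZN"], pd.to_datetime(["2019-12-02", "2019-12-03", "2019-12-04"])],
    names=["Symbol", "Date"],
)
df = pd.DataFrame(
    {
        "Close": [264.0, 259.0, 266.0, 1781.0, 1769.0, 1760.0],
        "highBound": [265.0, 265.0, 265.0, 1770.0, 1770.0, 1770.0],
    },
    index=idx,
)

def crossing_conditions(g, line):
    # Same two conditions as in the question: downward cross, then upward cross.
    prev_close, prev_line = g["Close"].shift(1), g[line].shift(1)
    return [
        (g["Close"] < g[line]) & (prev_close > prev_line),
        (g["Close"] > g[line]) & (prev_close < prev_line),
    ]

def as_array(g, line):
    # Returns a bare ndarray: the row labels of the group are lost.
    return np.select(crossing_conditions(g, line), [1, -1], default=np.nan)

def as_series(g, line):
    # Returns a Series that carries the group's (Symbol, Date) index.
    return pd.Series(as_array(g, line), index=g.index)

grouped = df.groupby(level="Symbol", group_keys=False)
per_symbol = grouped.apply(as_array, "highBound")   # one array per symbol
per_row = grouped.apply(as_series, "highBound")     # one value per row

print(per_symbol.index.tolist())        # labels are symbols only
print(per_row.index.equals(df.index))   # labels match the frame's MultiIndex

df["Hcross"] = per_row  # aligns on the (Symbol, Date) index
```

The array version is labelled by Symbol alone, so pandas has nothing to line up against the frame's (Symbol, Date) labels, while the Series version carries the full labels and assigns cleanly.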
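Better yet, given that your DataFrame is sorted, you don't need groupby.apply() at all. You can add a grouping condition by using shift on the Symbol level, which checks that the previous row belongs to the same symbol, so only a single np.select call is needed.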
line = 'highbound'
# Wrap the level values in a Series b/c Index.shift is only implemented for datetime-like indexes
s = pd.Series(df.index.get_level_values('Symbol'), index=df.index)
line_crossed_cond = [(s.eq(s.shift())
                      & (df['Close'] < df[line])
                      & (df['Close'].shift(1) > df[line].shift(1))),
                     (s.eq(s.shift())
                      & (df['Close'] > df[line])
                      & (df['Close'].shift(1) < df[line].shift(1)))]
df['Hcross'] = np.select(line_crossed_cond, [1, -1], default = np.nan)
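If you would rather not rely on the frame being sorted by symbol, one alternative (a sketch, not the approach shown above) is to take the previous values per group with groupby(...).shift(). The first row of each symbol then has NaN as its previous value, every comparison against it is False, and no crossing is reported across a symbol boundary. The prices below are made up so that a plain shift would report a spurious crossing on the first AMZN row.

```python
import numpy as np
import pandas as pd

# Made-up prices for illustration; both symbols share the same bound here.
idx = pd.MultiIndex.from_product(
    [["AAPL", "AMZN"], pd.to_datetime(["2019-12-02", "2019-12-03", "2019-12-04"])],
    names=["Symbol", "Date"],
)
df = pd.DataFrame(
    {
        "Close": [264.0, 266.0, 268.0, 260.0, 262.0, 267.0],
        "highbound": [265.0] * 6,
    },
    index=idx,
)

line = "highbound"
by_symbol = df.groupby(level="Symbol")

# Previous values within each symbol; NaN on each symbol's first row.
prev_close = by_symbol["Close"].shift(1)
prev_line = by_symbol[line].shift(1)

line_crossed_cond = [
    (df["Close"] < df[line]) & (prev_close > prev_line),  # crossed downward -> 1
    (df["Close"] > df[line]) & (prev_close < prev_line),  # crossed upward   -> -1
]
df["Hcross"] = np.select(line_crossed_cond, [1, -1], default=np.nan)
print(df)
```

This still uses a single np.select call. The trade-off is two grouped shifts instead of the symbol-equality mask.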