我有一个时间序列数据帧,其中有1或0(对/错)。我编写了一个循环遍历所有值为1的行的函数。给定用户定义的整数参数n_hold
,我将从初始行开始将值1设置为n行。
例如,在下面的数据框中,我将循环到第2016-08-05
行。如果为n_hold = 2
,那么我也将2016-08-08
和2016-08-09
都设置为1。
2016-08-03 0
2016-08-04 0
2016-08-05 1
2016-08-08 0
2016-08-09 0
2016-08-10 0
然后生成的df
将是
2016-08-03 0
2016-08-04 0
2016-08-05 1
2016-08-08 1
2016-08-09 1
2016-08-10 0
我遇到的问题是这正在运行10万次,而我当前的解决方案是遍历有行且子集太慢的行。我想知道是否有解决上述问题的方法真的很快。
这是我的(慢速)解决方案,x
是初始信号数据帧:
n_hold = 2
entry_sig_diff = x.diff()
entry_sig_dt = entry_sig_diff[entry_sig_diff == 1].index
final_signal = x * 0
for i in range(0, len(entry_sig_dt)):
row_idx = entry_sig_diff.index.get_loc(entry_sig_dt[i])
if (row_idx + n_hold) >= len(x):
break
final_signal[row_idx:(row_idx + n_hold + 1)] = 1
答案 0 :(得分:2)
完全更改了答案,因为与连续的1
值不同地工作:
说明:
解决方案先将{连续的1
删除,然后将where
与链式布尔掩码通过将ne
与shift
到{{1 }},将!=
用NaN
参数向前填充,最后将ffill
替换回:
limit
计时和比较输出:
0
n_hold = 2
s = x.where(x.ne(x.shift()) & (x == 1)).ffill(limit=n_hold).fillna(0, downcast='int')
时间:
np.random.seed(123)
x = pd.Series(np.random.choice([0,1], p=(.8,.2), size=1000))
x1 = x.copy()
#print (x)
def orig(x):
n_hold = 2
entry_sig_diff = x.diff()
entry_sig_dt = entry_sig_diff[entry_sig_diff == 1].index
final_signal = x * 0
for i in range(0, len(entry_sig_dt)):
row_idx = entry_sig_diff.index.get_loc(entry_sig_dt[i])
if (row_idx + n_hold) >= len(x):
break
final_signal[row_idx:(row_idx + n_hold + 1)] = 1
return final_signal
#print (orig(x))