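So I have a dataset made up of millions of rows of trade data, and I'm trying to apply a filter to it, as shown in the code below. The function trade_quant_filter
looks for outliers and adds the index of every outlier to a list so that they can be dealt with later.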
import numpy as np

def trim_moments(arr, alpha):
    # Trimmed mean and std: drop about n * alpha / 2 values from each end
    arr = np.sort(arr)  # np.sort returns a sorted copy; it does not sort in place
    n = len(arr)
    k = int(round(n * float(alpha)) / 2)
    return np.mean(arr[k:n - k]), np.std(arr[k:n - k])  # symmetric trim

def trade_quant_filter(dataset, alpha, window, gamma):
    # Returns the labels of rows whose price lies more than
    # 3 * (trimmed local std) + gamma away from the trimmed local mean.
    # Assumes a default 0..n-1 index, because count is also used as a position.
    radius = int(round(window / 2))
    bad_values = []
    for count, row in dataset.iterrows():
        if count < radius:  # Starting case when we can't have a symmetric radius
            local_mean, local_std = trim_moments(
                dataset['price'][: count + window].values, alpha)
            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)
        elif count > (dataset.shape[0] - radius):  # Ending case: use the window before this row
            local_mean, local_std = trim_moments(
                dataset['price'][count - window: count].values, alpha)
            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)
        else:  # General case: symmetric window centered on this row
            local_mean, local_std = trim_moments(
                dataset['price'][count - radius: count + radius].values, alpha)
            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)
    return bad_values
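Two details of trim_moments as first posted are easy to miss: np.sort(arr) returns a sorted copy and does not modify arr, and the slice arr[k+1:n-k] drops one more value from the low end than from the high end. A toy array with made-up values shows both effects:

```python
import numpy as np

prices = np.array([5.0, 100.0, 1.0, 2.0, 3.0])  # toy values; 100.0 is the outlier
n = len(prices)
k = 1  # trim one value from each end

# np.sort returns a new array; calling it without assigning leaves prices unsorted.
np.sort(prices)
unsorted_slice = prices[k:n - k]        # [100., 1., 2.]: the outlier survives the "trim"

# Keeping the sorted copy makes the slice an actual trim.
sorted_prices = np.sort(prices)         # [1., 2., 3., 5., 100.]
symmetric = sorted_prices[k:n - k]      # [2., 3., 5.]: one value dropped from each end
lopsided = sorted_prices[k + 1:n - k]   # [3., 5.]: two dropped from the bottom, one from the top
```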
The problem is that the code I wrote is far too inefficient to handle millions of entries. 150k rows take about 30 seconds:
stats4 = %prun -r trade_quant_filter(trades_reduced[:150000], alpha,window,gamma)
Ordered by: internal time
List reduced from 154 to 10 due to restriction <10>
ncalls tottime percall cumtime percall filename:lineno(function)
600002 3.030 0.000 3.030 0.000 {method 'reduce' of 'numpy.ufunc' objects}
150000 2.768 0.000 4.663 0.000 _methods.py:73(_var)
1 2.687 2.687 40.204 40.204 <ipython-input-102-4f16164d899e>:8(trade_quant_filter)
300000 1.738 0.000 1.738 0.000 {pandas.lib.infer_dtype}
6000025 1.548 0.000 1.762 0.000 {isinstance}
300004 1.481 0.000 6.937 0.000 internals.py:1804(make_block)
300001 1.426 0.000 13.157 0.000 series.py:134(__init__)
300000 1.033 0.000 3.553 0.000 common.py:1862(_possibly_infer_to_datetimelike)
150000 0.945 0.000 2.562 0.000 _methods.py:49(_mean)
300000 0.902 0.000 12.220 0.000 series.py:482(__getitem__)
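In this profile a large share of the cumulative time is spent inside pandas rather than in the arithmetic: series.py:134(__init__) accounts for 13.157 s over 300,001 calls and series.py:482(__getitem__) for 12.220 s, which points at building and indexing a Series for every row. A minimal sketch of the obvious first step, assuming the dataset has the default 0..n-1 index (the original code already relies on this when it uses count as a position), is to pull the price column out once and loop over a plain NumPy array. trade_quant_filter_np is a hypothetical name, and this trim_moments keeps the sorted copy and trims symmetrically, which are correctness fixes rather than speed changes.

```python
import numpy as np


def trim_moments(arr, alpha):
    """Trimmed mean and std: drop about n * alpha / 2 values from each end."""
    arr = np.sort(arr)                     # np.sort returns a copy; keep it
    n = len(arr)
    k = int(round(n * float(alpha)) / 2)
    trimmed = arr[k:n - k]                 # symmetric trim
    return trimmed.mean(), trimmed.std()


def trade_quant_filter_np(prices, alpha, window, gamma):
    """Same boundary logic as trade_quant_filter, but over a NumPy array.

    prices is the price column as an array, e.g. dataset['price'].values;
    the returned positions equal index labels only for a default 0..n-1 index.
    """
    prices = np.asarray(prices, dtype=float)
    n = len(prices)
    radius = int(round(window / 2))
    bad_values = []
    for i in range(n):
        if i < radius:                     # left edge: window starting at row 0
            local = prices[:i + window]
        elif i > n - radius:               # right edge: window ending just before row i
            local = prices[i - window:i]
        else:                              # interior: symmetric window around row i
            local = prices[i - radius:i + radius]
        local_mean, local_std = trim_moments(local, alpha)
        if abs(prices[i] - local_mean) > 3 * local_std + gamma:
            bad_values.append(i)
    return bad_values
```

Usage would be bad = trade_quant_filter_np(dataset['price'].values, alpha, window, gamma). This still runs one Python iteration and one small sort per row, so it only removes the pandas overhead; how much that buys should be measured with %prun as above rather than assumed.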
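There are a few factors that make optimizing this function more challenging:
functions such as rolling_mean.
Using Cython and NumPy ndarrays, as recommended here, seems like one possibility, and I am learning Cython now.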
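Because the moments are trimmed, pandas' rolling_mean (an untrimmed mean) does not apply directly, but the interior windows can still be processed in bulk without Cython. The sketch below assumes NumPy 1.20 or later for sliding_window_view. Each interior row i uses prices[i - radius : i + radius], which is exactly row i - radius of the sliding-window view, so the windows can be sorted and trimmed row by row in chunks; the few edge rows keep the one-sided windows of the original. trade_quant_filter_vec, trimmed_moments_2d and the chunk parameter are hypothetical names, and the code assumes the data has at least window rows.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def trimmed_moments_2d(windows, alpha):
    """Row-wise trimmed mean and std of a 2-D array (one window per row)."""
    w = np.sort(windows, axis=1)             # sorted copy of every window
    m = w.shape[1]
    k = int(round(m * float(alpha)) / 2)     # same trim-count formula as trim_moments
    trimmed = w[:, k:m - k]                  # drop k values from each end of each row
    return trimmed.mean(axis=1), trimmed.std(axis=1)


def trade_quant_filter_vec(prices, alpha, window, gamma, chunk=100_000):
    """Hypothetical vectorized variant; assumes len(prices) >= window."""
    prices = np.asarray(prices, dtype=float)
    n = len(prices)
    radius = int(round(window / 2))
    bad = np.zeros(n, dtype=bool)

    # Interior rows radius <= i <= n - radius: prices[i - radius : i + radius]
    # is row (i - radius) of this view, which shares memory with prices.
    windows = sliding_window_view(prices, 2 * radius)
    for start in range(0, windows.shape[0], chunk):
        block = windows[start:start + chunk]
        mean, std = trimmed_moments_2d(block, alpha)
        rows = np.arange(start, start + block.shape[0]) + radius
        bad[rows] = np.abs(prices[rows] - mean) > 3 * std + gamma

    # Edge rows keep the original one-sided windows; there are fewer than window of them.
    edge_rows = list(range(radius)) + list(range(n - radius + 1, n))
    for i in edge_rows:
        local = prices[:i + window] if i < radius else prices[i - window:i]
        mean, std = trimmed_moments_2d(local[np.newaxis, :], alpha)
        bad[i] = abs(prices[i] - mean[0]) > 3 * std[0] + gamma

    return np.flatnonzero(bad).tolist()
```

Sorting still costs O(window log window) per row, but it now happens inside NumPy instead of a Python loop; whether this reaches the 10x target would have to be measured on the real data. Each chunk allocates a sorted copy of chunk × 2·radius floats, so chunk trades memory for fewer Python-level iterations.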
What is the simplest way to optimize this code? I'm looking for a speedup of at least 10x.