使用所需的迭代优化Pandas代码

时间:2014-08-22 18:24:58

标签: python pandas cython

所以我有一个由数百万行交易数据组成的数据集,我试图应用过滤器,如下面的代码所示。函数trade_quant_filter查找异常值,然后将所有异常值的索引添加到列表中以便稍后处理。

def trim_moments(arr, alpha):
    np.sort(arr)
    n = len(arr)
    k = int(round(n*float(alpha))/2)
    return np.mean(arr[k+1:n-k]), np.std(arr[k+1:n-k])


def trade_quant_filter(dataset, alpha, window, gamma):
    radius = int(round(window /2))
    bad_values = []
    for count, row in dataset.iterrows():
         if count < radius: # Starting case when we can't have a symmetric radius 
            local_mean, local_std = trim_moments(
                    dataset['price'][: count + window].values,alpha)

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)

         elif count > (dataset.shape[0] - radius): # 2
            local_mean, local_std = trim_moments(
                    dataset['price'][count - window: count].values,alpha) 

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma:
                bad_values.append(count)       

         else:
            local_mean, local_std = trim_moments(
                    dataset['price'][count - radius: count + radius].values,alpha)

            if np.abs(dataset['price'][count] - local_mean) > 3 * local_std + gamma: #4
                bad_values.append(count)

    return bad_values

问题在于我编写的代码太差,无法处理数百万个条目。 150k行大约需要30秒:

stats4 = %prun -r trade_quant_filter(trades_reduced[:150000], alpha,window,gamma)

 Ordered by: internal time
   List reduced from 154 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   600002    3.030    0.000    3.030    0.000 {method 'reduce' of 'numpy.ufunc' objects}
   150000    2.768    0.000    4.663    0.000 _methods.py:73(_var)
        1    2.687    2.687   40.204   40.204 <ipython-input-102-4f16164d899e>:8(trade_quant_filter)
   300000    1.738    0.000    1.738    0.000 {pandas.lib.infer_dtype}
  6000025    1.548    0.000    1.762    0.000 {isinstance}
   300004    1.481    0.000    6.937    0.000 internals.py:1804(make_block)
   300001    1.426    0.000   13.157    0.000 series.py:134(__init__)
   300000    1.033    0.000    3.553    0.000 common.py:1862(_possibly_infer_to_datetimelike)
   150000    0.945    0.000    2.562    0.000 _methods.py:49(_mean)
   300000    0.902    0.000   12.220    0.000 series.py:482(__getitem__)

有一些因素可以让优化此功能变得更具挑战性:

  1. 据我所知,没有办法避免在这里逐行迭代,仍然采用修剪滚动方式和标准偏差。我计划研究如何在Pandas中实现像rolling_mean这样的函数。
  2. 字典的缺乏可用性也使得无法计算修剪的滚动方式和标准偏差,因此我无法将数据帧转换为字典。
  3. 使用Cython和NDarrays作为推荐here似乎是一种可能性,我现在正在学习Cython。

    优化此代码的最简单方法是什么?我正在寻找至少10倍的速度提升。

0 个答案:

没有答案