Optimizing a function that generates data from big tick data

Time: 2018-06-10 11:44:06

Tags: python pandas numpy optimization cumsum

I have a dataframe like this:

df_[['Price', 'Volume', 'Open', 'High', 'Low']]
Out[16]: 
                               Price  Volume  Open   High    Low
datetime                                                        
2016-05-01 22:00:00.334338092  45.90      20  45.9    NaN    NaN
2016-05-01 22:00:00.335312958    NaN       1  45.9    NaN    NaN
2016-05-01 22:00:00.538377726  45.92       1  45.9  45.90  45.90
2016-05-01 22:00:00.590386619  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       3  45.9  45.92  45.90
2016-05-01 22:00:00.590493308  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.591269949  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.591269949  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.591269949  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.707288056  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.719267600  45.92       2  45.9  45.92  45.90
2016-05-01 22:00:00.719267600  45.91       1  45.9  45.92  45.90
2016-05-01 22:00:00.731272008  45.92       1  45.9  45.92  45.90
2016-05-01 22:00:00.731272008  45.91       1  45.9  45.92  45.90
2016-05-01 22:00:00.738358786  45.92       1  45.9  45.92  45.90
(..omitted rows)

From this dataframe, I defined a function that generates a new dataframe:

res
Out[18]: 
                                High    Low  Open  Price  Volume
datetime                                                        
2016-05-01 22:00:00.334338092    NaN    NaN  45.9  45.90      20
2016-05-01 22:00:00.590493308    NaN    NaN  45.9  45.92      11
2016-05-01 22:00:00.731272008  45.92  45.90  45.9  45.91      10
2016-05-01 22:00:00.759276398  45.92  45.90  45.9  45.92      11
2016-05-01 22:00:00.927307727  45.92  45.90  45.9  45.90      36
2016-05-01 22:00:01.054379713  45.92  45.90  45.9  45.89      10
2016-05-01 22:00:01.251324161  45.92  45.89  45.9  45.92      10
2016-05-01 22:00:03.210540968  45.92  45.89  45.9  45.92      11
2016-05-01 22:00:04.450664460  45.92  45.89  45.9    NaN      10
2016-05-01 22:00:07.426789217  45.92  45.89  45.9  45.93      10
2016-05-01 22:00:10.394898254  45.96  45.89  45.9  45.93      10
2016-05-01 22:00:13.359080034  45.96  45.89  45.9  45.92      11
2016-05-01 22:00:17.434346718  45.96  45.89  45.9  45.92      17
2016-05-01 22:00:21.918598002  45.96  45.89  45.9  45.95      10
2016-05-01 22:00:28.587010136  45.96  45.89  45.9  45.94      10
2016-05-01 22:00:32.103168386  45.96  45.89  45.9  45.93      10
2016-05-01 22:01:04.451829835  45.96  45.89  45.9  45.94      14
2016-05-01 22:01:12.662589219  45.96  45.89  45.9  45.94      10
2016-05-01 22:01:17.823792647  45.96  45.89  45.9  45.94      10
2016-05-01 22:01:22.399158701  45.96  45.89  45.9  45.93      11
2016-05-01 22:01:23.511242124  45.96  45.89  45.9  45.92      10
(..omitted rows)

The function takes two arguments: df (the dataframe) and n (the target Volume size; for the output above, n=10). Starting from the first timestamp date_1, it computes the cumulative sum of Volume; the moment this cumulative sum becomes greater than or equal to n is date_2. The block of rows from date_1 to date_2 is then aggregated into a single row, as follows:

datetime : date_2

Price : price at date_2

Volume : sum of volume from date_1 to date_2

Open : price at date_1

High : max of high from date_1 to date_2

Low : min of low from date_1 to date_2

Repeat this until the end of the dataframe.
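
To make the rule concrete, here is a small sketch on invented toy data (all values are made up for illustration). It buckets ticks by which multiple of n their cumulative volume falls into, which matches the rule above except when a single block overshoots a boundary, and it uses Price for High/Low since the toy frame has no separate high/low columns:

```python
import pandas as pd

# Toy ticks -- invented data, for illustration only.
df = pd.DataFrame(
    {"Price": [45.90, 45.92, 45.91, 45.93, 45.92],
     "Volume": [6, 5, 4, 7, 9]},
    index=pd.date_range("2016-05-01 22:00:00", periods=5, freq="ms"),
)
n = 10

# Bucket each tick by which multiple of n its cumulative volume falls into.
bar_id = (df["Volume"].cumsum() - 1) // n

bars = df.groupby(bar_id).agg(
    Price=("Price", "last"),    # price at date_2
    Volume=("Volume", "sum"),   # sum of volume from date_1 to date_2
    Open=("Price", "first"),    # price at date_1
    High=("Price", "max"),      # max from date_1 to date_2
    Low=("Price", "min"),       # min from date_1 to date_2
)
```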

My problem is that the input dataframe has 60,000,000 rows, so aggregating it this way takes far too long. I want to optimize my code. Here it is:

import numpy as np
import pandas as pd


def tick_to_volume(df, n):
    np_df = np.array(df)  # convert to a numpy array
    res = pd.DataFrame()
    total_index = 0
    cum_n = 0
    cum_sum = np_df[:, 1].cumsum()  # cumulative sum of volume
    while True:
        cum_n += n

        # stop once the remaining volume can no longer reach the next threshold
        if total_index >= len(np_df) or cum_sum[-1] < cum_n:
            return res

        # first index where the cumulative volume reaches the threshold
        ix = (cum_sum[total_index:] >= cum_n).argmax()
        total_index += ix
        cum_n = cum_sum[total_index]
        np_df_to_agg = np_df[total_index - ix:(total_index + 1), :]  # rows to aggregate

        data = {'datetime': df.index[total_index],
                'Open': np_df_to_agg[0, 2],
                'High': max(np_df_to_agg[:, 3]),
                'Low': min(np_df_to_agg[:, 4]),
                'Price': np_df_to_agg[-1, 0],
                'Volume': sum(np_df_to_agg[:, 1])}

        df_to_append = pd.DataFrame([data])
        df_to_append.set_index('datetime', inplace=True)
        res = pd.concat([res, df_to_append])
        total_index += 1

2 Answers:

Answer 0 (score: 1)

Here is a partially vectorized approach. The idea is to split the problem into two parts:

  1. Determine the indices where each group starts and ends.
  2. Perform the groupby + agg with custom logic.

The second part is straightforward. The first part can be done efficiently with a little work plus numba.

We iterate along df.Volume, keeping track of the running sum x. Every time x reaches or exceeds n, we mark that row for later use and reset x = 0. After this pass, we have a series of indicators showing where each group ends. With some massaging and a bit of care for the first/last groups, we can turn df.Break into a series of group ids and proceed to the next step.

    import numpy as np
    from numba import njit
    
    n = 10
    
    
    @njit(fastmath=True)
    def find_breaks(vols, breaks):
        N = len(vols)
        acc = 0
        for i in range(N):
            acc += vols[i]
            if acc >= n:
                acc = 0
            breaks[i] = acc
        return
    
    
    # create a blank column to store group ids
    df["Break"] = np.nan
    # mark points where volumes spill over a threshold
    find_breaks(df.Volume.values, df.Break.values)
    # populate the ids implied by thresholds
    df["Break"] = (df.Break == 0).astype(float).replace(0, np.nan).cumsum().bfill()
    # handle the last group
    df["Break"] = df.Break.fillna(df.Break.max() + 1)
    
    # define an aggregator
    aggregator = {
        "Date": "last",
        "Price": "last",
        "Volume": "sum",
        "Open": "first",
        "High": "max",
        "Low": "min",
    }
    
    res = df.groupby("Break").agg(aggregator)
    
    # Date  Price  Volume  Open   High   Low
    # Break
    # 1.0    22:00:00.334338092  45.90      20  45.9    NaN   NaN
    # 2.0    22:00:00.590493308  45.92      11  45.9  45.92  45.9
    # 3.0    22:00:00.731272008  45.91      10  45.9  45.92  45.9
    # 4.0    22:00:00.738358786  45.92       1  45.9  45.92  45.9
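
For reference, here is a self-contained toy run of the same pipeline. Pure Python stands in for the @njit function (same logic, just slower), and the data and column layout are invented to mirror the question:

```python
import numpy as np
import pandas as pd

n = 10

def find_breaks(vols, breaks):
    # Running volume sum; reset to 0 whenever it reaches n.
    acc = 0
    for i in range(len(vols)):
        acc += vols[i]
        if acc >= n:
            acc = 0
        breaks[i] = acc

# Toy frame -- invented data mirroring the question's columns.
df = pd.DataFrame({
    "Date": pd.date_range("2016-05-01 22:00:00", periods=5, freq="ms"),
    "Price": [45.90, 45.92, 45.91, 45.93, 45.92],
    "Volume": [6, 5, 4, 7, 9],
    "Open": [45.9] * 5,
    "High": [45.90, 45.92, 45.92, 45.93, 45.93],
    "Low": [45.90] * 5,
})

# Write the running sums into a separate array, then attach it as a column
# (safer than mutating df.Break.values in place).
breaks = np.empty(len(df))
find_breaks(df["Volume"].to_numpy(), breaks)
df["Break"] = breaks

# Rows where the sum was reset (Break == 0) end a group; bfill spreads ids back.
df["Break"] = (df.Break == 0).astype(float).replace(0, np.nan).cumsum().bfill()
df["Break"] = df.Break.fillna(df.Break.max() + 1)  # last (possibly partial) group

res = df.groupby("Break").agg({
    "Date": "last", "Price": "last", "Volume": "sum",
    "Open": "first", "High": "max", "Low": "min",
})
```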
    

Answer 1 (score: 0)

Repeated pd.concat inside a loop performs terribly in Pandas and NumPy. So instead of this:

    res = pd.concat([res, df_to_append])

do this:

    res = []
    while True:
        ...
        res.append(df_to_append)
    res = pd.concat(res)
    res.set_index('datetime', inplace=True)

You can also simplify things by storing data as a tuple instead of a dictionary. The keys are the same on every iteration, and if you drop them you can fill res in the loop as a list of tuples, avoiding the construction of many temporary DataFrames and the repeated key lookups.
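
A minimal sketch of that list-of-tuples pattern (the loop body and values here are invented stand-ins for the real aggregation):

```python
import pandas as pd

rows = []  # accumulate plain tuples instead of concatenating per iteration
for i in range(3):  # stand-in for the while-loop over volume blocks
    # Fixed field order replaces the per-iteration dict keys.
    rows.append((pd.Timestamp("2016-05-01 22:00:00") + pd.Timedelta(seconds=i),
                 45.9, 45.92, 45.89, 45.91, 10 + i))

# Build the result DataFrame once, after the loop.
res = pd.DataFrame(rows, columns=["datetime", "Open", "High", "Low", "Price", "Volume"])
res.set_index("datetime", inplace=True)
```

Constructing the DataFrame once from a list is linear in the number of rows, whereas per-iteration pd.concat re-copies the accumulated result every time.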