Fast incrementing of many sub-spans of an array

Asked: 2013-12-25 08:33:35

Tags: python performance numpy pandas scipy

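Is anyone familiar with a faster way to increment spans of indices than the patch(ar, ranges) function below? In other words, it counts overlapping ranges, a bit like a histogram. Perhaps there is a vectorized method that already does something similar?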

import numpy as np
ar = np.zeros(10)
ranges = [(1, 2), (4, 7), (6, 8)]
add_this = 1

def patch(ar, ranges):
    for start, end in ranges:
        ar[start:end] += add_this
    return ar

print(np.arange(10) * 1.)
print(patch(ar, ranges))

Output:

[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9.]
[ 0.  1.  0.  0.  1.  1.  2.  1.  0.  0.]

My question is similar to:

  - Pandas: Overlapping time count
  - How to identify the maximum number of overlapping date ranges?

1 answer:

Answer 0 (score: 1)

I don't think NumPy has a built-in method for this, but it is easy to write a Cython function that speeds up the calculation:

First, create some random ranges:

import numpy as np

N = 1000
idx = np.random.randint(0, N, (100000, 2)).astype(np.uint64)
idx.sort(axis=1)
tidx = [tuple(x) for x in idx.tolist()]

Python for loop:

%%time
a = np.zeros(N)
for s,e in tidx:
    a[s:e] += 1

Output:

CPU times: user 459 ms, sys: 1e+03 µs, total: 460 ms
Wall time: 461 ms

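Define the Cython function (the %%cython cell magic has to be loaded first, e.g. with %load_ext Cython in current IPython):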

%%cython
import cython

@cython.wraparound(False)
@cython.boundscheck(False)
def patch(double[::1] a, size_t[:, ::1] idx, double v):
    cdef size_t i, j
    for i in range(idx.shape[0]):
        for j in range(idx[i, 0], idx[i, 1]):
            a[j] += v
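
The typed memoryviews in the signature above only accept buffers whose item type and layout match: `double[::1]` needs a C-contiguous float64 array, and `size_t[:, ::1]` needs a C-contiguous two-column array of unsigned integers the size of `size_t`. As a minimal sketch, the question's list of tuples could be prepared like this, assuming `np.uintp` has the same width as `size_t` on the target platform (true on common 64-bit systems, but an assumption here):

```python
import numpy as np

# Ranges in the list-of-tuples form used in the question.
ranges = [(1, 2), (4, 7), (6, 8)]

# C-contiguous (n, 2) array of pointer-sized unsigned integers.
# Assumption: np.uintp matches the C size_t expected by the Cython function.
idx = np.ascontiguousarray(ranges, dtype=np.uintp)

# np.zeros returns a C-contiguous float64 array, which fits double[::1].
a = np.zeros(10)

print(idx.shape, idx.flags.c_contiguous, a.dtype)

# With the compiled function available, the call would then be:
# patch(a, idx, 1.0)
```

Passing an array with a different dtype or a non-contiguous slice would raise an error at call time instead of being converted silently.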

Compute with Cython:

%%time
a2 = np.zeros(N)
patch(a2, idx, 1)

Output:

CPU times: user 18 ms, sys: 0 ns, total: 18 ms
Wall time: 17.7 ms

Check the result:

np.all(a == a2)

Output:

True
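
A NumPy-only alternative may also be worth trying. It is a different technique from the Cython loop benchmarked above: a difference array followed by a cumulative sum. Each range adds +1 at its start and -1 at its end, so the running total at each index is the number of ranges covering it. The sketch below is only an illustration. The function name `patch_diff` and its parameters are hypothetical, and it assumes every range satisfies 0 <= start <= end <= n.

```python
import numpy as np

def patch_diff(n, ranges, value=1.0):
    """Return a length-n array with `value` added over each half-open span [start, end).

    Hypothetical helper; assumes 0 <= start <= end <= n for every range.
    """
    r = np.asarray(ranges, dtype=np.intp).reshape(-1, 2)
    # Count how many ranges open and close at each index (index n is allowed as an end).
    opens = np.bincount(r[:, 0], minlength=n + 1)
    closes = np.bincount(r[:, 1], minlength=n + 1)
    # The running sum of (opens - closes) is the number of ranges covering each index.
    coverage = np.cumsum((opens - closes)[:n])
    return value * coverage

ranges = [(1, 2), (4, 7), (6, 8)]
result = patch_diff(10, ranges)
print(result)

# Cross-check against the straightforward slice loop on random ranges.
rng = np.random.RandomState(0)
idx = np.sort(rng.randint(0, 1000, (5000, 2)), axis=1)
expected = np.zeros(1000)
for s, e in idx:
    expected[s:e] += 1
print(np.array_equal(patch_diff(1000, idx), expected))
```

The work here grows with the number of ranges plus the array length, rather than with the total length of all spans. Whether it beats the Cython version depends on the data and has not been measured here.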