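Is anyone familiar with a faster way to increment spans of indices, as patch(ar, ranges) does below?
In other words, it counts overlapping ranges, somewhat like a histogram. Perhaps there is a vectorized method that already does something similar?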
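import numpy as np

ar = np.zeros(10)
ranges = [(1, 2), (4, 7), (6, 8)]  # half-open spans [start, end)
add_this = 1

def patch(ar, ranges):
    # Add add_this to every element covered by each span (modifies ar in place).
    for start, end in ranges:
        ar[start:end] += add_this
    return ar

print(np.arange(10) * 1.)
print(patch(ar, ranges))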
Output:
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[ 0. 1. 0. 0. 1. 1. 2. 1. 0. 0.]
My question is similar to:
- Pandas: Overlapping time count
- How to identify the maximum number of overlapping date ranges?
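Answer 0 (score: 1)
I don't think NumPy has a method for this, but it is very easy to write a Cython function that speeds up the computation.
First, create random ranges: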
import numpy as np
N = 1000
idx = np.random.randint(0, N, (100000, 2)).astype(np.uint64)
idx.sort(axis=1)
tidx = [tuple(x) for x in idx.tolist()]
Python for loop:
%%time
a = np.zeros(N)
for s, e in tidx:
    a[s:e] += 1
Output:
CPU times: user 459 ms, sys: 1e+03 µs, total: 460 ms
Wall time: 461 ms
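Define the Cython function:
%%cython
# Requires the Cython notebook extension (run %load_ext Cython in an earlier cell).
import cython

@cython.wraparound(False)
@cython.boundscheck(False)
def patch(double[::1] a, size_t[:, ::1] idx, double v):
    # Add v to a[j] for every j in each half-open span [idx[i, 0], idx[i, 1]).
    cdef size_t i, j
    for i in range(idx.shape[0]):
        for j in range(idx[i, 0], idx[i, 1]):
            a[j] += v
Compute with Cython: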
%%time
a2 = np.zeros(N)
patch(a2, idx, 1)
Output:
CPU times: user 18 ms, sys: 0 ns, total: 18 ms
Wall time: 17.7 ms
Check the result:
np.all(a == a2)
Output:
True
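For comparison, a pure-NumPy approach is also possible. The following is only a sketch and not part of the answer above: it uses a difference array, marking +value at each start and -value at each end, then taking a running sum. The name patch_vectorized and its parameters are hypothetical, and it assumes every span is half-open with 0 <= start <= end <= n. A span with start > end is silently skipped by the slicing loop but would produce negative counts here.

```python
import numpy as np

def patch_vectorized(n, ranges, value=1.0):
    """Add `value` to every half-open span [start, end) via a difference array.

    Hypothetical helper; assumes 0 <= start <= end <= n for every span.
    """
    ranges = np.asarray(ranges, dtype=np.intp).reshape(-1, 2)
    starts, ends = ranges[:, 0], ranges[:, 1]
    # One extra slot so that end == n is a valid index.
    diff = np.zeros(n + 1)
    # np.add.at accumulates correctly even when an index appears several times.
    np.add.at(diff, starts, value)
    np.add.at(diff, ends, -value)
    # The running sum turns the +value/-value markers into per-index totals.
    return np.cumsum(diff[:-1])

result = patch_vectorized(10, [(1, 2), (4, 7), (6, 8)])
print(result)
```

For the three spans from the question this gives the same counts as the loop version. The work is proportional to the number of spans plus the array length, rather than to the total length of all spans, so it may help most when spans are long.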