由于一段应该执行以下操作的python代码,我面临性能问题:
我有2个数组A和B,其中包含未排序的值,我想构建一个新的数组C,它将包含以下每个索引:
C[i]= sum(flag*B[k] for k so that flag = A[k]<=A[i])
我是以两种方式完成的:
1)非常直接的方式:
M = len(A)
C = np.zeros(M)
for i in xrange(M):
value = A[i]
flag = A <= value
C[i] = np.sum(flag * B)
2)尝试使用numpy sort函数:
indices_sorted = np.argsort(A)
C_sort = np.zeros(M)
for i in xrange(M):
index = np.where(indices_sorted==i)
for k in xrange(index[0][0]+1):
C_sort[i] += B[indices_sorted[k]]
结果是第一个对于5000个元素阵列来说要快得多(因子40-50)。
我没想到第二次尝试那么糟糕,第一次尝试也不够快......
你们能给我一个更好的方法吗?
提前致谢。
答案 0 :(得分:3)
假设A
和B
是相同形状的一维数组,您可以通过将A
扩展到2D数组然后进行比较来使用broadcasting
,从而基本上以矢量化方式比较每个元素与每个其他元素。然后,使用B
执行元素乘法,再次使broadcasting
起作用。最后沿第二轴求和以获得最终输出。实现看起来像这样 -
C = ((A <= A[:,None])*B).sum(1)
您可以使用elementwise multiplication and summing
使用matrix-multiplication
模拟C = (A <= A[:,None]).dot(B)
的相同行为,以获得更有效的解决方案,例如 -
row,col = np.nonzero(A <= A[:,None])
C = np.bincount(row,np.take(B,col))
以下是基于使用np.dot
建立索引并使用np.take
进行计数的另一种方法 -
2D
对于大型数据,创建(A <= A[:,None]
matrix-multiplication
掩码的内存开销可能会抵消性能。因此,作为对现有循环代码的优化,您可以引入np.sum(flag * B)
来替换元素乘法和求和。因此,flag.dot(B)
可以替换为M = len(A)
C = np.empty(M)
for i in xrange(M):
C[i] = (A <= A[i]).dot(B)
。引入一些其他优化技巧,你会有一个像这样的修改版本 -
idx = A.argsort()
C = B[idx].cumsum()[idx.argsort()]
终于!以下是np.bincount
-
A
以下是对其工作原理和原因的快速解释:
您正在执行元素比较,然后根据比较结果对B中的元素求和。现在,如果C
是排序数组,则输出cumsum
基本上是B
B
版本的A
。因此,对于通用未排序的情况,您需要按cumsum
的argsort对def org_app(A,B):
M = len(A)
C = np.zeros(M)
for i in range(M):
value = A[i]
flag = A <= value
C[i] = np.sum(flag * B)
return C
def sum_based(A,B):
return ((A <= A[:,None])*B).sum(1)
def dot_based(A,B):
return (A <= A[:,None]).dot(B)
def bincount_based(A,B):
row,col = np.nonzero(A <= A[:,None])
return np.bincount(row,np.take(B,col))
def org_app_modified(A,B):
M = len(A)
C = np.empty(M)
for i in xrange(M):
C[i] = (A <= A[i]).dot(B)
return C
def cumsum_trick(A,B):
idx = A.argsort()
return B[idx].cumsum()[idx.argsort()]
进行排序,对其执行In [212]: # Inputs
...: N = 5000
...: A = np.random.rand(N)
...: B = np.random.rand(N)
...:
In [213]: %timeit org_app(A,B)
...: %timeit sum_based(A,B)
...: %timeit dot_based(A,B)
...: %timeit bincount_based(A,B)
...: %timeit org_app_modified(A,B)
...: %timeit cumsum_trick(A,B)
...:
1 loops, best of 3: 266 ms per loop
1 loops, best of 3: 411 ms per loop
1 loops, best of 3: 322 ms per loop
1 loops, best of 3: 1.01 s per loop
10 loops, best of 3: 196 ms per loop
1000 loops, best of 3: 835 µs per loop
,最后根据原始未排序的顺序重新排列元素。
运行时测试
定义方法 -
{{1}}
设置输入和时间 -
{{1}}