高效的一维数组比较,缩放和求和

时间:2015-11-25 18:21:32

标签: python arrays performance sorting numpy

由于一段应该执行以下操作的python代码,我面临性能问题:

我有2个数组A和B,其中包含未排序的值,我想构建一个新的数组C,它将包含以下每个索引:

C[i]= sum(flag*B[k] for k so that flag = A[k]<=A[i])

我是以两种方式完成的:

1)非常直接的方式:

M = len(A)
C = np.zeros(M)
for i in xrange(M):
    value = A[i]
    flag = A <= value
    C[i] = np.sum(flag * B)

2)尝试使用numpy sort函数:

indices_sorted = np.argsort(A)
C_sort = np.zeros(M)
for i in xrange(M):
    index = np.where(indices_sorted==i)
    for k in xrange(index[0][0]+1):
        C_sort[i] += B[indices_sorted[k]]

结果是第一个对于5000个元素阵列来说要快得多(因子40-50)。

我没想到第二次尝试那么糟糕,第一次尝试也不够快......

你们能给我一个更好的方法吗?

提前致谢。

1 个答案:

答案 0 :(得分:3)

假设AB是相同形状的一维数组,您可以通过将A扩展到2D数组然后进行比较来使用broadcasting,从而基本上以矢量化方式比较每个元素与每个其他元素。然后,使用B执行元素乘法,再次使broadcasting起作用。最后沿第二轴求和以获得最终输出。实现看起来像这样 -

C = ((A <= A[:,None])*B).sum(1)

您可以使用elementwise multiplication and summing使用matrix-multiplication模拟C = (A <= A[:,None]).dot(B) 的相同行为,以获得更有效的解决方案,例如 -

row,col = np.nonzero(A <= A[:,None])
C = np.bincount(row,np.take(B,col))

以下是基于使用np.dot建立索引并使用np.take进行计数的另一种方法 -

2D

对于大型数据,创建(A <= A[:,None] matrix-multiplication掩码的内存开销可能会抵消性能。因此,作为对现有循环代码的优化,您可以引入np.sum(flag * B)来替换元素乘法和求和。因此,flag.dot(B)可以替换为M = len(A) C = np.empty(M) for i in xrange(M): C[i] = (A <= A[i]).dot(B) 。引入一些其他优化技巧,你会有一个像这样的修改版本 -

idx = A.argsort()
C = B[idx].cumsum()[idx.argsort()]

终于!以下是np.bincount -

的获胜者
A

以下是对其工作原理和原因的快速解释:

您正在执行元素比较,然后根据比较结果对B中的元素求和。现在,如果C是排序数组,则输出cumsum基本上是B B版本的A。因此,对于通用未排序的情况,您需要按cumsum的argsort对def org_app(A,B): M = len(A) C = np.zeros(M) for i in range(M): value = A[i] flag = A <= value C[i] = np.sum(flag * B) return C def sum_based(A,B): return ((A <= A[:,None])*B).sum(1) def dot_based(A,B): return (A <= A[:,None]).dot(B) def bincount_based(A,B): row,col = np.nonzero(A <= A[:,None]) return np.bincount(row,np.take(B,col)) def org_app_modified(A,B): M = len(A) C = np.empty(M) for i in xrange(M): C[i] = (A <= A[i]).dot(B) return C def cumsum_trick(A,B): idx = A.argsort() return B[idx].cumsum()[idx.argsort()] 进行排序,对其执行In [212]: # Inputs ...: N = 5000 ...: A = np.random.rand(N) ...: B = np.random.rand(N) ...: In [213]: %timeit org_app(A,B) ...: %timeit sum_based(A,B) ...: %timeit dot_based(A,B) ...: %timeit bincount_based(A,B) ...: %timeit org_app_modified(A,B) ...: %timeit cumsum_trick(A,B) ...: 1 loops, best of 3: 266 ms per loop 1 loops, best of 3: 411 ms per loop 1 loops, best of 3: 322 ms per loop 1 loops, best of 3: 1.01 s per loop 10 loops, best of 3: 196 ms per loop 1000 loops, best of 3: 835 µs per loop ,最后根据原始未排序的顺序重新排列元素。

运行时测试

定义方法 -

{{1}}

设置输入和时间 -

{{1}}