如何在numpy中使基于序列的函数更快?

时间:2019-01-23 06:33:26

标签: python performance numpy for-loop vectorization

考虑以下功能:

import numpy as np

a = np.ones(16).reshape(4,4)

def fn(a):
    b = np.array(a)
    for i in range(b.shape[0]):
        for j in range(b.shape[1] - 1):
            b[i][j+1] += b[i][j]
    return b

print(fn(a))

也就是说,对于基于数组中的t+1计算t的通用函数,我可以使其更快吗?我知道有一个np.vectorize,但似乎不适合这种情况。

3 个答案:

答案 0 :(得分:1)

您可以使用cumsum,我认为这会有所帮助。

import numpy as np
import pandas as pd
a = np.ones(16).reshape(4,4)
df =pd.DataFrame(a)
df.cumsum(axis=1)

或者您可以使用np.cumsum()

np.cumsum(a,axis=1)  

答案 1 :(得分:1)

有可能将两个for循环减少到一个for循环,而无需额外的复制开销。

In [86]: a 
Out[86]: 
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

In [87]: b = a.copy() 

In [88]: for col in range(b.shape[1]-1): 
    ...:     b[:, col+1] = np.sum(a[:, :col+2], axis=1) 

In [89]: b
Out[89]: 
array([[1., 2., 3., 4.],
       [1., 2., 3., 4.],
       [1., 2., 3., 4.],
       [1., 2., 3., 4.]])

要使此功能适用于泛型函数,您可以在numpy中寻找等效功能,或者使用numpy操作(矢量化的)实现一个功能。对于您提供的示例,我只是使用numpy.sum()来完成我们的工作。

就性能而言,此方法比在索引级别使用两个for循环进行操作要好得多,尤其是对于较大的数组。在我上面使用的方法中,我们使用列切片。


以下是建议的时机,它们比本机python实现的速度快了 3倍


原生Python:

def fn(a):
    b = np.array(a)
    for i in range(b.shape[0]):
        for j in range(b.shape[1] - 1):
            b[i][j+1] += b[i][j]
    return b

略微向量化:

In [104]: def slightly_vectorized(b): 
     ...:     for col in range(b.shape[1]-1): 
     ...:         b[:, col+1] = np.sum(a[:, :col+2], axis=1) 
     ...:     return b 

In [100]: a = np.ones(625).reshape(25, 25) 

In [101]: %timeit fn(a) 
303 µs ± 2.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [102]: b = a.copy() 

In [103]: %timeit slightly_vectorized(b) 
99.8 µs ± 501 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

答案 2 :(得分:0)

您要寻找的东西称为accumulate,下面是一个示例:

import numpy as np
from itertools import accumulate

def fn(a):
    acc = accumulate(a, lambda prev, row: prev + row)
    return np.array(list(acc))

a = np.arange(16).reshape(4, 4)
print(fn(a))
# [[ 0  1  2  3]
#  [ 4  6  8 10]
#  [12 15 18 21]
#  [24 28 32 36]]

在numpy中没有优化的累加函数,因为实际上不可能以高性能和通用的方式编写累加。 python的实现是通用的,但执行起来很像手工编写的lok。

要获得最佳性能,您可能需要查找或编写所需的特定累加函数的底层实现。您已经提到了numba,也可以研究cython。