矩阵乘法的最佳numba实现在很大程度上取决于矩阵大小

时间:2016-03-29 16:12:41

标签: python numpy matrix numba

这个问题与我发布的一段时间有关:
Python, numpy, einsum multiply a stack of matrices

我试图理解为什么当乘以一堆矩阵时,我以特定的方式使用Numba时获得的加速。和以前一样,我放入一个(500,201,2,2)数组,在第一个轴的末端乘以(2x2)矩阵(所以500次乘法),得到一个(201,2,2)数组作为结果

这是Python代码:

from numba import jit  # numba 0.24, numpy 1.9.3, python 2.7.11

Arr = rand(500,201,2,2)

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1,len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

@jit(nopython=True)
def loopMultJit(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            ArrMult[i] = np.dot(ArrMult[i], Arr[j, i])
    return ArrMult

@jit(nopython=True)
def loopMultJit_2X2(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            x1 = ArrMult[i,0,0] * Arr[j,i,0,0] + ArrMult[i,0,1] * Arr[j,i,1,0]
            y1 = ArrMult[i,0,0] * Arr[j,i,0,1] + ArrMult[i,0,1] * Arr[j,i,1,1]
            x2 = ArrMult[i,1,0] * Arr[j,i,0,0] + ArrMult[i,1,1] * Arr[j,i,1,0]
            y2 = ArrMult[i,1,0] * Arr[j,i,0,1] + ArrMult[i,1,1] * Arr[j,i,1,1]
            ArrMult[i,0,0] = x1
            ArrMult[i,0,1] = y1
            ArrMult[i,1,0] = x2
            ArrMult[i,1,1] = y2
    return ArrMult

A1 = loopMult(Arr)
A2 = loopMultJit(Arr)
A3 = loopMultJit_2X2(Arr)

print np.allclose(A1, A2)
print np.allclose(A1, A3)

%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_2X2(Arr)

这是输出:

True
True
10 loops, best of 3: 40.5 ms per loop
10 loops, best of 3: 36 ms per loop
1000 loops, best of 3: 808 µs per loop

在之前的问题中,接受的答案显示,在没有详细优化的情况下,f2py的加速比为8倍。在这里,使用Numba,我使用numba在einsum循环上获得大约10%的加速,但是如果不是在循环中使用np.dot,我可以获得45倍的加速,我只需手动进行2x2矩阵乘法。为什么是这样?我应该提到我已经使用正确的类型签名实现了这两个jit函数作为guvectorize版本,这基本上提供了相同的加速因子,所以我把它们排除了。迭代超过201,500,2,2矩阵的速度也很快。

1 个答案:

答案 0 :(得分:1)

2评论已经回应说加速只是由于python开销,我认为这是正确的。开销主要是函数调用,但也适用于循环,而np.dot除此之外还有一些额外的开销。我设置了一个天真点产品功能:

@jit(nopython=True)
def dot(mat1, mat2):
    s = 0
    mat = np.empty(shape=(mat1.shape[1], mat2.shape[0]), dtype=mat1.dtype)
    for r1 in range(mat1.shape[0]):
        for c2 in range(mat2.shape[1]):
            s = 0
            for j in range(mat2.shape[0]):
                s += mat1[r1,j] * mat2[j,c2]
            mat[r1,c2] = s
    return mat

然后我设置了函数来乘以数组,一个调用点函数,另一个调用dot函数,这样就可以在没有额外函数调用的情况下执行:

@jit(nopython=True)
def loopMultJit_dot(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            ArrMult[i] = dot(ArrMult[i], Arr[j, i])
    return ArrMult

@jit(nopython=True)
def loopMultJit_dotInternal(Arr):
    ArrMult = np.empty(shape=Arr.shape[1:], dtype=Arr.dtype)
    for i in range(0, Arr.shape[1]):
        ArrMult[i] = Arr[0, i]
        for j in range(1, Arr.shape[0]):
            s = 0.0
            for r1 in range(ArrMult.shape[1]):
                for c2 in range(Arr.shape[3]):
                    s = 0.0
                    for r2 in range(Arr.shape[2]):
                        s += ArrMult[i,r1,r2] * Arr[j,i,r2,c2]
                    ArrMult[i,r1,c2] = s
    return ArrMult

然后我可以运行2个比较:2x2阵列和10x10阵列。通过这些,我可以了解一般的函数调用,特别是np.dot函数调用的惩罚,以及np.dot中BLAS优化的收益:

print "2x2 Time Test:"
Arr = rand(500,201,2,2)
%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_2X2(Arr)
%timeit loopMultJit_dot(Arr)
%timeit loopMultJit_dotInternal(Arr)

print "10x10 Time Test:"
Arr = rand(500,201,10,10)
%timeit loopMult(Arr)
%timeit loopMultJit(Arr)
%timeit loopMultJit_dot(Arr)
%timeit loopMultJit_dotInternal(Arr)

产生:

2x2 Time Test:
10 loops, best of 3: 55.8 ms per loop  # einsum
10 loops, best of 3: 48.7 ms per loop  # np.dot
1000 loops, best of 3: 1.09 ms per loop  # 2x2
10 loops, best of 3: 28.3 ms per loop  # naive dot, separate function
100 loops, best of 3: 2.58 ms per loop  # naive dot internal

10x10 Time Test:
1 loop, best of 3: 499 ms per loop  # einsum
10 loops, best of 3: 91.3 ms per loop  # np.dot
10 loops, best of 3: 170 ms per loop  # naive dot, separate function
10 loops, best of 3: 161 ms per loop  # naive dot internal

我认为带回家的消息是:

    如果你不使用numba,或者需要单行,那么
  • einsum很不错,但对于矩阵乘法,有更快的选择
  • 如果您正在使用小型矩阵,手动操作可以更快,而不是调用单独的函数
  • 对于大型矩阵,有一个原因是发明了BLAS,事实上,在小到10x10的情况下,加速非常明显。