How can I speed up this Numba matrix multiplication?

Posted: 2019-06-19 14:35:01

Tags: python numpy matrix-multiplication numba

I am trying to reproduce matrix multiplication with Numba. Here is the code:

import numpy as np
import timeit
from numba import jit, float64, prange


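@jit('float64[:,:](float64[:,:],float64[:,:])', parallel=True, nopython=True)
def matmul(A, B):
    # Result has shape (rows of A, columns of B)
    C = np.zeros((A.shape[0], B.shape[1]))
    for i in prange(A.shape[0]):
        # Numba only parallelizes the outermost prange; this inner
        # prange is executed as an ordinary serial loop
        for j in prange(B.shape[1]):
            # Sum over the shared inner dimension (A.shape[1] == B.shape[0])
            for k in range(A.shape[1]):
                C[i,j] = C[i,j] + A[i,k]*B[k,j]
    return C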



if __name__ == '__main__':
    m_size = 1000
    num_loops = 10
    A = np.random.rand(m_size, m_size)
    B = np.random.rand(m_size, m_size)

    # Numpy
    start = timeit.default_timer()
    for i in range(num_loops):
        A.dot(B)
    stop = timeit.default_timer()
    execution_time = stop - start
    print("Numpy Executed in ", execution_time)


    # Numba
    start = timeit.default_timer()
    for i in range(num_loops):
        matmul(A, B)
    stop = timeit.default_timer()
    execution_time = stop - start
    print("Numba Executed in ", execution_time) 

Here is the output:

Numpy Executed in  0.713342247006949
Numba Executed in  17.631791604988393
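To put these numbers in perspective, the timings can be converted into a rough throughput. This sketch assumes a naive n×n product costs 2·n³ floating-point operations (one multiply and one add per inner-loop step) and ignores call overhead; the variable names simply reuse the values printed above.

```python
# Rough throughput estimate from the timings reported above.
# Assumption: an n x n matmul costs 2 * n**3 floating-point operations
# (one multiply and one add per inner-loop step); overhead is ignored.
m_size = 1000
num_loops = 10
numpy_seconds = 0.713342247006949
numba_seconds = 17.631791604988393

flops_total = 2 * m_size**3 * num_loops  # 2e10 operations over all loops
numpy_gflops = flops_total / numpy_seconds / 1e9
numba_gflops = flops_total / numba_seconds / 1e9
slowdown = numba_seconds / numpy_seconds

print(f"NumPy: {numpy_gflops:.1f} GFLOP/s")
print(f"Numba: {numba_gflops:.2f} GFLOP/s")
print(f"Numba version is about {slowdown:.0f}x slower")
```

That works out to roughly 28 GFLOP/s for NumPy versus about 1.1 GFLOP/s for the Numba loop, a gap of around 25x, which is far larger than JIT overhead alone would explain.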

In a related post, Numba and NumPy performed very similarly. What am I doing wrong, and how can I improve the performance of the matmul function?
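A likely cause is the loop order rather than Numba itself. With NumPy's default C (row-major) layout, the innermost `for k` loop reads `B[k, j]` down a column, so each step jumps a whole row of `B` (8000 bytes for a 1000×1000 float64 matrix); this defeats the cache and prevents SIMD vectorization. In addition, Numba only parallelizes the outermost `prange`, so the inner `prange` over `j` runs serially. The sketch below is one commonly suggested change, not a guaranteed match for NumPy: it switches to an i-k-j order so the innermost loop walks rows of `B` and `C` contiguously, keeps a single outer `prange`, and sums over `A.shape[1]`. The name `matmul_ikj` and the no-Numba fallback decorator are illustrative assumptions.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs
    # (slowly) when Numba is not installed.
    prange = range

    def njit(*args, **kwargs):
        # No-op stand-in for numba.njit; supports @njit and @njit(...)
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit(parallel=True)
def matmul_ikj(A, B):
    n, m = A.shape
    m2, p = B.shape
    if m != m2:
        raise ValueError("inner dimensions do not match")
    C = np.zeros((n, p))
    for i in prange(n):        # only this outermost loop runs in parallel
        for k in range(m):     # shared inner dimension, A.shape[1]
            a = A[i, k]        # hoisted scalar reused across the j loop
            for j in range(p):  # contiguous walk over rows of B and C
                C[i, j] += a * B[k, j]
    return C
```

Even with this ordering, `A.dot(B)` calls an optimized BLAS library that uses cache blocking and multithreading, so a plain triple loop may still lag behind it. If the goal is simply speed inside Numba code, Numba also supports `np.dot` on contiguous float arrays (this requires SciPy to be installed). When timing the sketch, call it once before starting the timer: without an explicit signature, Numba compiles on the first call.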
