我尝试使用numba重现矩阵分解。这里的代码:
import numpy as np
import timeit
from numba import jit, float64, prange
@jit('float64[:,:](float64[:,:],float64[:,:])', parallel=True, nopython=True)
def matmul(A, B):
C = np.zeros((A.shape[0], B.shape[1]))
for i in prange(A.shape[0]):
for j in prange(B.shape[1]):
for k in range(A.shape[0]):
C[i,j] = C[i,j] + A[i,k]*B[k,j]
return C
if __name__ == '__main__':
m_size = 1000
num_loops = 10
A = np.random.rand(m_size, m_size)
B = np.random.rand(m_size, m_size)
# Numpy
start = timeit.default_timer()
for i in range(num_loops):
A.dot(B)
stop = timeit.default_timer()
execution_time = stop - start
print("Numpy Executed in ", execution_time)
# Numba
start = timeit.default_timer()
for i in range(num_loops):
matmul(A, B)
stop = timeit.default_timer()
execution_time = stop - start
print("Numba Executed in ", execution_time)
这是输出:
Numpy Executed in 0.713342247006949
Numba Executed in 17.631791604988393
在related post中,numba和numpy的性能非常接近。 我做错了什么,如何改善matmul函数的性能?