括号引起的矩阵乘法的执行时间差

时间:2017-08-22 10:16:32

标签: numpy python-3.6 parentheses

鉴于两个1D numpy数组ab

N = 100000
a = np.randn(N)
b = np.randn(N)

为什么以下两个表达式之间存在相当大的执行时间差异:

# expression 1
c = a @ a * b @ b

# expression 2
c = (a @ a) * (b @ b)

使用Jupyter Notebook的%timeit魔法,我得到以下结果:

  

%timeit a @ a * b @ b

     

每循环223μs±6.97μs(平均值±标准偏差,7次运行,每次1000次循环)

  

%timeit(a @ a)*(b @ b)

     

每回路17.4μs±27.3 ns(平均值±标准偏差,7次运行,每次100000次循环)

1 个答案:

答案 0 :(得分:2)

在两个版本中,您都会使用长度为N的向量的两个点积。但是,另外第一个解决方案执行N次乘法,而第二个解决方案只需要一次。

a @ a * b @ b相当于((a @ a) * b) @ b

aa = a @ a  # N multiplications and additions -> scalar
aab = aa * b  # N multiplications -> vector
aabb = aab @ b  # N multiplications and additions -> scalar

(a @ a) * (b @ b)相当于

aa = a @ a  # N multiplications and additions -> scalar
bb = b @ b  # N multiplications and additions -> scalar
aabb = aa * bb  # 1 multiplication -> scalar

矩阵乘法性能取决于如何设置括号的事实是众所周知的。存在通过利用这一事实来优化matrix chain multiplication的算法。

更新:正如我刚才所了解的那样,numpy具有优化多重矩阵乘法的功能:numpy.linalg.multidot