无循环的n个矩阵的乘积运算符

时间:2018-08-22 20:20:06

标签: python numpy

我得到了尺寸为(nx2x2)的矩阵,所以设了(3x2x2)的矩阵M:

[[[ 1.  5.]
  [ 2.  4.]]

 [[ 5. 25.]
  [10. 20.]]

 [[ 5. 25.]
  [10. 20.]]]

我想对每个2x2矩阵进行乘积运算(点积)。换句话说,我想做以下事情:

enter image description here

其中Mj与python中的M[j,:,:]相同。最简单的方法是使用这样的for循环:

prod=np.identity(2)
for temp in M:
    prod=np.dot(prod,temp)

但是我不知道是否有矢量化的方法(我敢肯定这可能是一个重复的问题,但是我找不到适合谷歌的问题)

2 个答案:

答案 0 :(得分:3)

In [99]: X = np.array([1,5,2,4,5,25,10,20,5,25,10,20]).reshape(3,2,2)
In [100]: X
Out[100]: 
array([[[ 1,  5],
        [ 2,  4]],

       [[ 5, 25],
        [10, 20]],

       [[ 5, 25],
        [10, 20]]])
In [101]: prod=np.identity(2)
In [102]: for t in X:
     ...:     prod=np.dot(prod,t)
     ...:     
In [103]: prod
Out[103]: 
array([[1525., 3875.],
       [1550., 3850.]])

使用mutmul运算符:

In [104]: X[0]@X[1]@X[2]
Out[104]: 
array([[1525, 3875],
       [1550, 3850]])
In [105]: 

链接点函数:

In [106]: np.linalg.multi_dot(X)
Out[106]: 
array([[1525, 3875],
       [1550, 3850]])

检查其文档和代码,以查看其如何尝试优化计算。

How is numpy multi_dot slower than numpy.dot?

numpy large matrix multiplication optimize

Multiply together list of matrices in Numpy


如果我多次展开X

In [147]: X.shape
Out[147]: (48, 2, 2)

一个点产品的成本:

In [148]: timeit X[0]@X[1]
5.08 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

其中48个:

In [149]: 5*48
Out[149]: 240

将其与迭代评估链点的成本进行比较:

In [150]: %%timeit
     ...: Y = np.eye(X.shape[1]).astype(X.dtype)
     ...: for i in X:
     ...:     Y = Y@i
     ...:     
227 µs ± 5.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

因此,链接只需要评估所有必需的点就不会花费任何费用。

我尝试了一个执行嵌套点的递归函数,并且没有任何改进。

除非您可以在数学上提出一种顺序评估点的替代方法,否则我怀疑是否还有改进的余地。使用cython之类的方法直接调用基础BLAS函数可能是唯一的选择。它不会在单个点上有所改善,但可以减少一些调用开销。

答案 1 :(得分:2)

如果您只想提速,这是一个解决方案(我将这台计算机的速度提高了10倍)

import numpy as np
import numba as nb

big_array = np.random.rand(400, 400, 400)

def numpy_method(array3d):
    return np.linalg.multi_dot(array3d)

@nb.jit(nopython=True)
def numba_method(array3d):
    prod = np.identity(array3d.shape[2])
    for i in nb.prange(array3d.shape[0]):
        temp = array3d[i, :, :]
        prod = np.dot(prod, temp)

    return prod

numpy_result = numpy_method(big_array)
numba_result = numba_method(big_array)

print(np.all(np.isclose(numpy_result, numba_result)))

%timeit numpy_method(big_array)
9.66 s ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit numba_method(big_array)
923 ms ± 7.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

这不能很好地扩展到较小的数组,但是当数组大到您说的那样时,使用LLVM编译器编译代码确实会产生明显的变化。

此外,当我使数组变大时,它们都会使float中的可用内存过载,并转换为inf值。我猜测您的矩阵不会遇到此问题。

相关问题