我得到了尺寸为(nx2x2)的矩阵,所以设了(3x2x2)的矩阵M:
[[[ 1. 5.]
[ 2. 4.]]
[[ 5. 25.]
[10. 20.]]
[[ 5. 25.]
[10. 20.]]]
我想对每个2x2矩阵进行乘积运算(点积)。换句话说,我想做以下事情:
其中Mj与python中的M[j,:,:]
相同。最简单的方法是使用这样的for循环:
prod=np.identity(2)
for temp in M:
prod=np.dot(prod,temp)
但是我不知道是否有矢量化的方法(我敢肯定这可能是一个重复的问题,但是我找不到适合谷歌的问题)
答案 0 :(得分:3)
In [99]: X = np.array([1,5,2,4,5,25,10,20,5,25,10,20]).reshape(3,2,2)
In [100]: X
Out[100]:
array([[[ 1, 5],
[ 2, 4]],
[[ 5, 25],
[10, 20]],
[[ 5, 25],
[10, 20]]])
In [101]: prod=np.identity(2)
In [102]: for t in X:
...: prod=np.dot(prod,t)
...:
In [103]: prod
Out[103]:
array([[1525., 3875.],
[1550., 3850.]])
使用mutmul
运算符:
In [104]: X[0]@X[1]@X[2]
Out[104]:
array([[1525, 3875],
[1550, 3850]])
In [105]:
链接点函数:
In [106]: np.linalg.multi_dot(X)
Out[106]:
array([[1525, 3875],
[1550, 3850]])
检查其文档和代码,以查看其如何尝试优化计算。
How is numpy multi_dot slower than numpy.dot?
numpy large matrix multiplication optimize
Multiply together list of matrices in Numpy
如果我多次展开X
:
In [147]: X.shape
Out[147]: (48, 2, 2)
一个点产品的成本:
In [148]: timeit X[0]@X[1]
5.08 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
其中48个:
In [149]: 5*48
Out[149]: 240
将其与迭代评估链点的成本进行比较:
In [150]: %%timeit
...: Y = np.eye(X.shape[1]).astype(X.dtype)
...: for i in X:
...: Y = Y@i
...:
227 µs ± 5.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
因此,链接只需要评估所有必需的点就不会花费任何费用。
我尝试了一个执行嵌套点的递归函数,并且没有任何改进。
除非您可以在数学上提出一种顺序评估点的替代方法,否则我怀疑是否还有改进的余地。使用cython
之类的方法直接调用基础BLAS
函数可能是唯一的选择。它不会在单个点上有所改善,但可以减少一些调用开销。
答案 1 :(得分:2)
如果您只想提速,这是一个解决方案(我将这台计算机的速度提高了10倍)
import numpy as np
import numba as nb
big_array = np.random.rand(400, 400, 400)
def numpy_method(array3d):
return np.linalg.multi_dot(array3d)
@nb.jit(nopython=True)
def numba_method(array3d):
prod = np.identity(array3d.shape[2])
for i in nb.prange(array3d.shape[0]):
temp = array3d[i, :, :]
prod = np.dot(prod, temp)
return prod
numpy_result = numpy_method(big_array)
numba_result = numba_method(big_array)
print(np.all(np.isclose(numpy_result, numba_result)))
%timeit numpy_method(big_array)
9.66 s ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit numba_method(big_array)
923 ms ± 7.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
这不能很好地扩展到较小的数组,但是当数组大到您说的那样时,使用LLVM编译器编译代码确实会产生明显的变化。
此外,当我使数组变大时,它们都会使float中的可用内存过载,并转换为inf值。我猜测您的矩阵不会遇到此问题。