numpy大矩阵乘法优化

时间:2018-02-15 06:48:17

标签: numpy

我必须使用大矩阵进行迭代计算: R(t)= M @ R(t-1),其中M是n×n,R是n×1

如果我写这个:

for _ in range(iter_num):
    R = M @ R

我认为它会非常慢,因为每次都必须复制并创建新数组。这有什么方法可以优化这个? (也许在里面做?)

4 个答案:

答案 0 :(得分:3)

有几个时间表明OP的方法实际上很有竞争力:

>>> import functools as ft
>>> kwds = dict(globals=globals(), number=1000)
>>> R = np.random.random((200,))
>>> M = np.random.random((200, 200))
>>> def f_op(M, R):
...     for i in range(k):
...         R = M@R
...     return R
... 
>>> def f_pp(M, R):
...     return ft.reduce(np.matmul, (R,) + k * (M.T,))
... 
>>> def f_ag(M, R):
...     return np.linalg.matrix_power(M, k)@R
... 
>>> def f_tai(M, R):
...     return np.linalg.multi_dot([M]*k+[R])
... 
>>> k = 20
>>> repeat('f_op(M, R)', **kwds)
[0.14156094897771254, 0.1264056910004001, 0.12611976702464744]
>>> repeat('f_pp(M, R)', **kwds)
[0.12594187198556028, 0.1227772050187923, 0.12045996301458217]
>>> repeat('f_ag(M, R)', **kwds)
[2.065609384997515, 2.041590739012463, 2.038702343008481]
>>> repeat('f_tai(M, R)', **kwds)
[3.426795684004901, 3.4321794749994297, 3.4208814119920135]
>>>
>>> k = 500
>>> repeat('f_op(M, R)', **kwds)
[3.066054102004273, 3.0294102499901783, 3.020273027010262]
>>> repeat('f_pp(M, R)', **kwds)
[2.891954762977548, 2.8680382019956596, 2.8558325179910753]
>>> repeat('f_ag(M, R)', **kwds)
[5.216210452985251, 5.1636185249954, 5.157578871003352]

答案 1 :(得分:2)

使用numpy.linalg.multi_dot

np.linalg.multi_dot([M] * iter_num + [R]) 

[M] * iter_num创建对M的引用列表。)

文档中提到了一些想法,

  

(multi_dot)在单个函数调用中计算两个或多个数组的点积,同时自动选择最快的求值顺序。

  

将multi_dot视为:

     

def multi_dot(arrays):返回functools.reduce(np.dot,arrays)

注意OP的方法实际上非常快。请参阅Paul Panzer关于更多时间结果的答案。

感谢Paul Panzer建议使用引用而不是查看。

答案 2 :(得分:1)

这对你有用吗?

R_final = np.linalg.matrix_power(M, iter_num) @ R

好像你正在做M @ M @ M @ ... @ M @ R,它可以被投射到M ** iter_num @ R

答案 3 :(得分:1)

如果iter_numn相比较大(假设np.lialg.matrix_power尚未执行此操作)且M是可逆的,则使用显式频谱分解会很有用:

def mat_pow(a, p):
    vals, vecs = np.linalg.eig(a)
    return vecs @ np.diag(vals**p) @ vecs.T

mat_pow(M, iter_num) @ R

如果M是对称的,您可以使用更快np.linalg.eigh

相关问题