如何最有效地计算矩阵乘积的对角线

时间:2020-05-08 20:52:10

标签: python numpy optimization pytorch linear-algebra

我要计算以下内容:

import numpy as np
n= 3
m = 2
x = np.random.randn(n,m)

#Method 1
y = np.zeros(m)
for i in range(m):
    y[i] = x[:,i] @ x[:,i]

#Method 2
y2 = np.diag(x.T @ x)

第一种方法有一个问题,它使用for循环,效率不是很高(我需要在GPU上的py​​torch中这样做数百万次)

当我只需要对角线条目时,第二种方法将计算全矩阵乘积,因此也不是很有效。

我想知道是否存在任何聪明的方法?

1 个答案:

答案 0 :(得分:2)

使用人工构造的和积。您需要各个列的平方和:

y = (x * x).sum(axis=0)

正如Divakar建议的那样,np.einsum可能会提供较少的内存密集型选项,因为它不需要临时数组x * x

y = np.einsum('ij,ij->j', x, x)
相关问题