计算python中大矩阵的踪迹

时间:2018-11-05 19:27:41

标签: python numpy

我有一个矩阵X,我需要编写一个函数,该函数计算矩阵enter image description here的踪迹。

我写了下一个脚本:

import numpy as np
def test(matrix):
    return (np.dot(matrix, matrix.T)).trace()

np.random.seed(42)
matrix = np.random.uniform(size=(1000, 1))

print(test(matrix))

它在小矩阵上工作正常,但是当我尝试在大矩阵上进行计算时(例如,在形状为(50000, 1)的矩阵上),它给了我一个内存错误。

我试图在网站上的其他问题中找到解决问题的方法,但没有任何帮助。我将不胜感激!

1 个答案:

答案 0 :(得分:4)

您要计算的数字只是X的所有条目的平方和。对平方求和,而不是计算一个充满不需要的条目的巨型矩阵乘积:

return (X**2).sum()

或者散乱矩阵并使用dot,对于连续的X来说可能更快:

raveled = X.ravel()
return raveled.dot(raveled)

实际上,ravel对于不连续的X也可能会更快-即使ravel需要复制,它也不会比{ {1}}。