X是n x p矩阵,其中p远大于n。假设n = 1000且p = 500000.当我跑:
X = np.random.randn(1000,500000)
S = X.dot(X.T)
执行此操作最终会占用大量内存,尽管结果大小为1000 x 1000.一旦操作完成,内存使用会恢复。有没有办法解决这个问题?
答案 0 :(得分:6)
问题不在于X
和X.T
是相同内存空间本身的视图,
而是X.T
是F-连续的而不是C-连续的。当然,这必须
对于该情况中的至少一个输入阵列必然是正确的
你将数组与其转置视图相乘的地方。
numpy< 1.8,np.dot
会
创建任何 F-ordered输入数组的C-ordered副本,而不仅仅是碰巧在同一块上的视图
存储器中。
例如:
X = np.random.randn(1000,50000)
Y = np.random.randn(50000, 100)
# X and Y are both C-order, no copy
%memit np.dot(X, Y)
# maximum of 1: 485.554688 MB per loop
# make X Fortran order and Y C-order, now the larger array (X) gets
# copied
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 867.070312 MB per loop
# make X C-order and Y Fortran order, now the smaller array (Y) gets
# copied
X = np.ascontiguousarray(X)
Y = np.asfortranarray(Y)
%memit np.dot(X, Y)
# maximum of 1: 523.792969 MB per loop
# make both of them F-ordered, both get copied!
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 905.093750 MB per loop
如果复制是一个问题(例如当X
非常大时),您可以采取哪些措施?
最好的选择可能是升级到更新版本的numpy - 正如@perimosocordiae所指出的,这个性能问题已在this pull request中得到解决。
如果出于某种原因无法升级numpy,还有一个技巧可以让你通过scipy.linalg.blas
直接调用相关的BLAS函数来执行基于BLAS的快速点积而无需强制复制(无耻地)从this answer被盗:
from scipy.linalg import blas
X = np.random.randn(1000,50000)
%memit res1 = np.dot(X, X.T)
# maximum of 1: 845.367188 MB per loop
%memit res2 = blas.dgemm(alpha=1., a=X.T, b=X.T, trans_a=True)
# maximum of 1: 471.656250 MB per loop
print np.all(res1 == res2)
# True