减少广播和总和的内存使用量

时间:2019-03-30 16:56:08

标签: python numpy

如何用更少的内存执行以下操作?

a * b使用的内存比必要的多9倍。

是否可以将np.sum(a * b, axis=3)替换为np.tensordot

谢谢。

import numpy as np
x = np.random.choice(100, size=(23, 10, 3))
a = x[:, :, np.newaxis, :]
b = x[:, np.newaxis, :, :]
y = np.sum(a * b, axis=3)

1 个答案:

答案 0 :(得分:2)

<button onclick="clickHandler()">Click here!</button>
<span id="res"></span>

我不确定您怎么算这9次。

这也可以用In [749]: x = np.random.choice(100, size=(23, 10, 3)) ...: a = x[:, :, np.newaxis, :] ...: b = x[:, np.newaxis, :, :] ...: y = np.sum(a * b, axis=3) In [750]: a.shape Out[750]: (23, 10, 1, 3) # a view, no extra memory In [751]: b.shape Out[751]: (23, 1, 10, 3) In [752]: y.shape Out[752]: (23, 10, 10) In [753]: (a*b).shape Out[753]: (23, 10, 10, 3) # 3x larger than y 表示:

einsum

我不确定它的内存使用情况如何比较。原始形式在“ ijkl”空间上进行迭代。

速度更快:

In [758]: np.einsum('ijl,ikl->ijk', x, x).shape                                 
Out[758]: (23, 10, 10)
In [759]: np.allclose(np.einsum('ijl,ikl->ijk', x, x),y)                        
Out[759]: True

花了一些工夫,但我发现使用In [760]: timeit np.einsum('ijl,ikl->ijk', x, x).shape 74.1 µs ± 256 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) In [761]: timeit y = np.sum(a * b, axis=3) 90.9 µs ± 86.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 的方法更快:

matmul

这会将更多工作移至快速编译的库。我不能说存储器的用途。

在重复项中找到的更简单的解决方案要快一些:

In [771]: (a@b.transpose(0,1,3,2)).shape                                        
Out[771]: (23, 10, 1, 10)
In [772]: np.allclose((a@b.transpose(0,1,3,2)).squeeze(),y)                     
Out[772]: True
In [773]: timeit (a@b.transpose(0,1,3,2)).shape                                 
20 µs ± 28 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)