分别给定形状为R
和S
的NumPy数组(m, d)
和(m, n, d)
,我想计算形状P
的数组(m, n)
其(i, j)
个条目为np.dot(R[i, :] , S[i, j, :])
。
执行双重for循环不需要任何额外的空间(除了m * n
的{{1}}空间),但不会节省时间。
使用广播,我可以P
,但这会花费额外的P = np.sum(R[:, np.newaxis, :] * S, axis=2)
空间。
这样做的时间和空间效率最高的是什么?
答案 0 :(得分:4)
在这些情况下,最好考虑numba
,它可以提供两全其美的优势:
import numpy as np
from numba import jit
def vanilla_mult(R, S):
m, n = R.shape[0], S.shape[1]
result = np.empty((m, n), dtype=R.dtype)
for i in range(m):
for j in range(n):
result[i, j] = np.dot(R[i, :], S[i, j,:])
return result
def broadcast_mult(R, S):
return np.sum(R[:, np.newaxis, :] * S, axis=2)
@jit(nopython=True)
def jit_mult(R, S):
m, n = R.shape[0], S.shape[1]
result = np.empty((m, n), dtype=R.dtype)
for i in range(m):
for j in range(n):
result[i, j] = np.dot(R[i, :], S[i, j,:])
return result
注意,vanilla_mult
和jit_mult
具有完全相同的实现,但后者是即时编译的。我们来测试一下:
In [1]: import test # the above is in test.py
In [2]: import numpy as np
In [3]: m, n, d = 100, 100, 100
In [4]: R = np.random.rand(m, d)
In [5]: S = np.random.rand(m, n, d)
行...
In [6]: %timeit test.broadcast_mult(R, S)
100 loops, best of 3: 1.95 ms per loop
In [7]: %timeit test.vanilla_mult(R, S)
100 loops, best of 3: 11.7 ms per loop
哎呀,是的,与广播相比,计算时间增加了近5倍。然而...
In [8]: %timeit test.jit_mult(R, S)
The slowest run took 760.57 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 870 µs per loop
尼斯!我们可以通过简单的JITing将运行时间缩短一半!这如何扩展?
In [12]: m, n, d = 1000, 1000, 100
In [13]: R = np.random.rand(m, d)
In [14]: S = np.random.rand(m, n, d)
In [15]: %timeit test.vanilla_mult(R, S)
1 loop, best of 3: 1.22 s per loop
In [16]: %timeit test.broadcast_mult(R, S)
1 loop, best of 3: 666 ms per loop
In [17]: %timeit test.jit_mult(R, S)
The slowest run took 7.59 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 83.6 ms per loop
很好地扩展非常,因为必须创建大型中间阵列才开始阻止广播,与vanilla方法相比,它只有一半的时间,但它需要几乎7倍的时间和JIT方法一样多!
最后,我们比较np.einsum
方法:
In [19]: %timeit np.einsum('md,mnd->mn', R, S)
10 loops, best of 3: 59.5 ms per loop
显然它是速度的赢家。不过,我对它的评论空间要求不够熟悉。
答案 1 :(得分:4)
einsum
是另一个常见的嫌疑人
m, n, d = 100, 100, 100
>>> R = np.random.random((m, d))
>>> S = np.random.random((m, n, d))
>>> np.einsum('md,mnd->mn', R, S)
>>> np.allclose(np.einsum('md,mnd->mn', R, S), (R[:,None,:]*S).sum(axis=-1))
True
>>> from timeit import repeat
>>> repeat('np.einsum("md,mnd->mn", R, S)', globals=globals(), number=1000)
[0.7004671019967645, 0.6925274690147489, 0.6952172230230644]
>>> repeat('(R[:,None,:]*S).sum(axis=-1)', globals=globals(), number=1000)
[3.0512512560235336, 3.0466731210472062, 3.044075728044845]
有些间接证据表明einsum
对RAM不太浪费:
>>> m, n, d = 1000, 1001, 1002
>>> # Too much for broadcasting:
>>> np.zeros((m, n, d))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
MemoryError
>>> R = np.random.random((m, d))
>>> S = np.random.random((n, d))
>>> np.einsum('md,nd->mn', R, S).shape
(1000, 1001)