我正在尝试尽可能快地计算许多3x1向量对的交叉积。此
n = 10000
a = np.random.rand(n, 3)
b = np.random.rand(n, 3)
numpy.cross(a, b)
给出了正确的答案,但是由this answer to a similar question推动,我认为einsum
会让我到达某个地方。我发现了两个
eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1
np.einsum('ijk,aj,ak->ai', eijk, a, b)
np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
计算交叉产品,但它们的表现令人失望:两种方法的表现都比np.cross
差得多:
%timeit np.cross(a, b)
1000 loops, best of 3: 628 µs per loop
%timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
100 loops, best of 3: 9.02 ms per loop
%timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
100 loops, best of 3: 10.6 ms per loop
有关如何改进einsum
的任何想法?
答案 0 :(得分:2)
您可以使用np.tensordot
引入矩阵乘法,以丢失第一级维度之一,然后使用np.einsum
丢失其他维度,如此 -
np.einsum('aik,ak->ai',np.tensordot(a,eijk,axes=([1],[1])),b)
或者,我们可以使用a
在b
和np.einsum
之间执行广播的元素乘法,然后使用np.tensordot
一次性丢失两个维度,如此 -
np.tensordot(np.einsum('aj,ak->ajk', a, b),eijk,axes=([1,2],[1,2]))
我们可以通过引入新的轴来执行元素乘法,例如a[...,None]*b[:,None]
,但它似乎会减慢它。
尽管如此,这些方法比仅基于np.einsum
的提议方法有了很好的改进,但未能超越np.cross
。
运行时测试 -
In [26]: # Setup input arrays
...: n = 10000
...: a = np.random.rand(n, 3)
...: b = np.random.rand(n, 3)
...:
In [27]: # Time already posted approaches
...: %timeit np.cross(a, b)
...: %timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
...: %timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
...:
1000 loops, best of 3: 298 µs per loop
100 loops, best of 3: 5.29 ms per loop
100 loops, best of 3: 9 ms per loop
In [28]: %timeit np.einsum('aik,ak->ai',np.tensordot(a,eijk,axes=([1],[1])),b)
1000 loops, best of 3: 838 µs per loop
In [30]: %timeit np.tensordot(np.einsum('aj,ak->ajk',a,b),eijk,axes=([1,2],[1,2]))
1000 loops, best of 3: 882 µs per loop
答案 1 :(得分:2)
einsum()
的乘法运算次数多于cross()
,而在最新的NumPy版本中,cross()
不会创建许多临时数组。因此einsum()
不能比cross()
快。
以下是旧的交叉代码:
x = a[1]*b[2] - a[2]*b[1]
y = a[2]*b[0] - a[0]*b[2]
z = a[0]*b[1] - a[1]*b[0]
以下是新的交叉代码:
multiply(a1, b2, out=cp0)
tmp = array(a2 * b1)
cp0 -= tmp
multiply(a2, b0, out=cp1)
multiply(a0, b2, out=tmp)
cp1 -= tmp
multiply(a0, b1, out=cp2)
multiply(a1, b0, out=tmp)
cp2 -= tmp
要加速它,你需要cython或numba。