NumPy Einsum的黑巫毒教

时间:2014-11-26 09:43:55

标签: python numpy

我使用einsum函数获得了一些工作代码。但是因为einsum目前仍然像我black voodoo。我想知道,这段代码实际上是做什么的,以及是否可以使用np.dot

以某种方式进行优化

我的数据看起来像这样

n, p, q = 40000, 8, 4
a = np.random.rand(n, p, q)
b = np.random.rand(n, p)

我现有的函数einsum函数看起来像这样

f1 = np.einsum("ijx,ijy->ixy", a, a)
f2 = np.einsum("ijx,ij->ix", a, b)

但它真正做到了什么?我到达此处:每个维度(轴)由标签表示,i等于第一个轴nj表示第二个轴p和{{ 1}}和x是同一轴y的不同标签。 因此q的输出数组的顺序为f1,因此输出形状为ixy

但就我而言。并且

1 个答案:

答案 0 :(得分:1)

让我们玩几个小阵列

In [110]: a=np.arange(2*3*4).reshape(2,3,4)

In [111]: b=np.arange(2*3).reshape(2,3)

In [112]: np.einsum('ijx,ij->ix',a,b)
Out[112]: 
array([[ 20,  23,  26,  29],
       [200, 212, 224, 236]])

In [113]: np.diagonal(np.dot(b,a)).T
Out[113]: 
array([[ 20,  23,  26,  29],
       [200, 212, 224, 236]])

np.dot在第一个数组的最后一个dim上运行,第二个到第二个数组的最后一个运行。所以我必须切换参数,以便3维度排成一行。 dot(b,a)生成一个(2,2,4)数组。 diagonal选择其中2个“行”,然后进行转置以进行清理。另一个einsum很好地表达了清理:

In [122]: np.einsum('iik->ik',np.dot(b,a))

由于np.dot生成的数组比原始einsum更大,因此即使基础C代码更紧密,也不会更快。

(奇怪的是我无法用np.dot(b,a)复制einsum;它不会生成那个(2,2,...)数组。

对于a,a情况,我们必须做类似的事情 - 滚动一个数组的轴,使最后一个维度与另一个数据的第二个到最后一个排列,执行dot,然后清理使用diagonaltranspose

In [157]: np.einsum('ijx,ijy->ixy',a,a).shape
Out[157]: (2, 4, 4)
In [158]: np.einsum('ijjx->jix',np.dot(np.rollaxis(a,2),a))
In [176]: np.diagonal(np.dot(np.rollaxis(a,2),a),0,2).T

tensordot是另一种在所选轴上采用dot的方式。

np.tensordot(a,a,(1,1))
np.diagonal(np.rollaxis(np.tensordot(a,a,(1,1)),1),0,2).T  # with cleanup