此4D einsum操作的张量是多少?

时间:2018-12-11 17:49:53

标签: python numpy tensordot

这是一个简单的代码,将4D矩阵a与3D矩阵b进行“批处理”:

from functools import reduce
import numpy as np
from operator import mul

def einsum(a, b):
    return np.einsum('ijkl,jkl->ikl', a, b)

def original(a, b):
    s0, s1, s2, s3 = a.shape
    c = np.empty((s0, s2, s3))
    for j in range(s3):
        for i in range(s2):
            c[:, j, i] = np.dot(a[:, :, j, i], b[:, j, i])
    return c

sz_a = (16, 4, 512, 512)
sz_b = (4, 512, 512)

a = np.random.random(reduce(mul, sz_a)).reshape(sz_a)
b = np.random.random(reduce(mul, sz_b)).reshape(sz_b)

有关计时:

%timeit original(a, b)
395 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit einsum(a, b)
23.1 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

我想测试一下tensordot的性能,以了解它的比较情况,但是在围绕如何在这里使用它方面确实存在一些麻烦。如果有人熟悉我的指导,将不胜感激。谢谢!

我最初的想法是:

np.tensordot(a, b, axes=((1),(0)))

但是这给了我一个MemoryError,所以我认为那是不对的...

1 个答案:

答案 0 :(得分:0)

您的einsum与等效的matmul的时间比较:

In [910]: timeit (a.transpose(2,3,0,1)@b[:,None].transpose(2,3,0,1)).transpose(2,3,0,1)[:
     ...: ,0]
90.5 ms ± 92.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [911]: timeit np.einsum('ijkl,jkl->ikl', a, b)
92.7 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

时间太短了,我怀疑einsum优化实际上是在使用matmul。最初einsum使用其自己的编译的乘积和迭代,但是最近进行了最近的更改,它使用了多种方法,包括dotmatmul(如果适用)。

matmul的创建是为了处理初始尺寸代表一堆矩阵的情况。在您的问题中,最后2个维度是该堆栈,其中dot作用于首字母。创建matmul来处理这种堆积的点。 dot及其派生词tensordot无法处理这种堆叠。