Tensordot的性能瓶颈

时间:2018-01-31 03:59:34

标签: python performance numpy multidimensional-array dot-product

在我尝试理解numpy.tensordot()时,我尝试了文档中的示例,并确信我们可以通过tensordot参数的不同排列获得完全相同的axes结果。例如,轴的下面两个排列是等价的(即它们都产生相同的结果):

In [28]: a = np.arange(60.).reshape(3,4,5)
In [29]: b = np.arange(24.).reshape(4,3,2)

In [30]: perm1 = np.tensordot(a, b, axes=[(1, 0), (0, 1)])
In [31]: perm2 = np.tensordot(a, b, axes=[(0, 1), (1, 0)])

In [32]: np.all(perm1 == perm2)
Out[32]: True

然而,在测量性能的同时,我发现一个排列比另一个更快超过2倍,这让我很困惑。

# setting up input arrays
In [19]: a = np.arange(30*40*50).reshape(30,40,50)
In [20]: b = np.arange(40*30*20).reshape(40,30,20)

# contracting the first two axes from the input tensors
In [21]: %timeit np.tensordot(a, b, axes=[(0, 1), (1, 0)])
3.23 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# an equivalent way of contraction of the first two
# axes from the input tensors as in the above case
In [22]: %timeit np.tensordot(a, b, axes=[(1, 0), (0, 1)])
1.62 ms ± 16.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

那么,后一种情况下 2x加速的原因是什么?是否与NumPy ndarrays在内存中的内部结构有关?或者是其他东西?提前感谢您的见解!

1 个答案:

答案 0 :(得分:1)

在不详细说明的情况下,这两个计算会重新创建tensordot所采取的操作,并生成相同的perm值。

它们表现出相同的2倍速差:

In [24]: timeit np.dot(a.transpose(2,0,1).reshape(50,-1), b.transpose(1,0,2).reshape(-1,20))
4.39 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [25]: timeit np.dot(a.transpose(2,1,0).reshape(50,-1), b.reshape(-1,20))
2.99 ms ± 97.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我的猜测是,第二个更快,因为b.reshape(-1,20)不需要副本,而转置后跟第一个重塑形式。

计划不同的重塑:

In [28]: timeit a.transpose(2,1,0).reshape(50,-1)
128 µs ± 978 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [29]: timeit a.transpose(2,0,1).reshape(50,-1)
1.04 µs ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [30]: timeit b.reshape(-1,20)
501 ns ± 14.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [31]: timeit b.transpose(1,0,2).reshape(-1,20)
27.5 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

速度存在显着差异。 [30]只是view,因此可以解释为什么它如此之快。我猜[28]速度慢得多,因为它涉及元素的完全反转,其中[29]副本(40,50)块。

相关问题