PyTorch张量沿任意轴的积àNumPy的`tensordot`

时间:2018-07-10 13:25:11

标签: matrix-multiplication pytorch dot-product

NumPy提供了非常有用的tensordot函数。它使您可以沿任意轴(其尺寸匹配)计算两个ndarrays的乘积。我很难在PyTorch中找到任何类似的东西。 mm仅适用于2D阵列,matmul具有某些不良的广播特性。

我想念什么吗?我真的是要重塑阵列以模仿要使用mm的产品吗?

2 个答案:

答案 0 :(得分:3)

原始答案是完全正确的,但作为更新,Pytorch now supports tensordot 本身就是如此。与 numpy 相同的调用签名,但将 axes 更改为 dims

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)

答案 1 :(得分:1)

如@McLawrence所述,此功能目前正在讨论中(issue thread)。

同时,您可以考虑使用torch.einsum(),例如:

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)