NumPy提供了非常有用的tensordot
函数。它使您可以沿任意轴(其尺寸匹配)计算两个ndarrays
的乘积。我很难在PyTorch中找到任何类似的东西。 mm
仅适用于2D阵列,matmul
具有某些不良的广播特性。
我想念什么吗?我真的是要重塑阵列以模仿要使用mm
的产品吗?
答案 0 :(得分:3)
原始答案是完全正确的,但作为更新,Pytorch now supports tensordot
本身就是如此。与 numpy 相同的调用签名,但将 axes
更改为 dims
。
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
答案 1 :(得分:1)
如@McLawrence所述,此功能目前正在讨论中(issue thread)。
同时,您可以考虑使用torch.einsum()
,例如:
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", (a, b))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)