执行批处理矩阵-pytorch中的多个权重矩阵相乘

时间:2018-09-17 06:15:27

标签: python pytorch

我有一批大小为torch.Size([batch_size, 9, 5])的矩阵 A 和权重矩阵 B 大小为torch.Size([3, 5, 6])。在Keras中,简单的K.dot(A, B)能够处理矩阵乘法以提供大小为(batch_size, 9, 3, 6)的输出。在这里,将 A 中的每一行与 B 中的3个矩阵相乘,以形成(3x6)矩阵。

您如何在割炬中执行类似的操作。根据文档,torch.bmm要求 A B 必须具有相同的批处理大小,因此我尝试这样做:

B = B.unsqueeze(0).repeat((batch_size, 1, 1, 1))
B.size() # torch.Size([batch_size, 3, 5, 6])
torch.bmm(A,B) # gives an error
  

RuntimeError:参数2无效:预期的3D张量,得到4D

好吧,这是预期的错误,但是我该如何执行此操作?

1 个答案:

答案 0 :(得分:0)

您可以使用einstein notation将您要执行的操作描述为bxy,iyk->bxik。因此,您可以使用einsum进行计算。

torch.einsum('bxy,iyk->bxik', (A, B))将为您提供所需的答案。