我有一批大小为torch.Size([batch_size, 9, 5])
的矩阵 A 和权重矩阵 B 大小为torch.Size([3, 5, 6])
。在Keras中,简单的K.dot(A, B)
能够处理矩阵乘法以提供大小为(batch_size, 9, 3, 6)
的输出。在这里,将 A 中的每一行与 B 中的3个矩阵相乘,以形成(3x6)
矩阵。
您如何在割炬中执行类似的操作。根据文档,torch.bmm
要求 A 和 B 必须具有相同的批处理大小,因此我尝试这样做:
B = B.unsqueeze(0).repeat((batch_size, 1, 1, 1))
B.size() # torch.Size([batch_size, 3, 5, 6])
torch.bmm(A,B) # gives an error
RuntimeError:参数2无效:预期的3D张量,得到4D
好吧,这是预期的错误,但是我该如何执行此操作?
答案 0 :(得分:0)
您可以使用einstein notation将您要执行的操作描述为bxy,iyk->bxik
。因此,您可以使用einsum
进行计算。
torch.einsum('bxy,iyk->bxik', (A, B))
将为您提供所需的答案。