3,4轴火炬的矩阵乘法

时间:2019-12-20 14:56:39

标签: python pytorch matrix-multiplication

我有两个a(16,8,8,64)b(64,64)形状的张量。假设我将a的最后一个维度提取到另一个列向量c中,我想计算matmul(matmul(c.T, b), c)。我希望在a的前3个维度中都这样做。那就是最终产品的形状应为(16,8,8,1)。如何在pytorch中实现此目标?

1 个答案:

答案 0 :(得分:0)

可以执行以下操作:

row_vec = a[:, :, :, None, :].float()
col_vec = a[:, :, :, :, None].float()
b = (b[None, None, None, :, :]).float()
prod = torch.matmul(torch.matmul(row_vec, b), col_vec)