我有两个a(16,8,8,64)
和b(64,64)
形状的张量。假设我将a
的最后一个维度提取到另一个列向量c
中,我想计算matmul(matmul(c.T, b), c)
。我希望在a
的前3个维度中都这样做。那就是最终产品的形状应为(16,8,8,1)
。如何在pytorch中实现此目标?
答案 0 :(得分:0)
可以执行以下操作:
row_vec = a[:, :, :, None, :].float()
col_vec = a[:, :, :, :, None].float()
b = (b[None, None, None, :, :]).float()
prod = torch.matmul(torch.matmul(row_vec, b), col_vec)