您好,我正在重写在Pytorch中编写的模块,我需要在Numpy中进行所有代码计算,而我正在努力在Pytorch中翻译这一行feature.addmm_(1, -2, feat, feat.t())
在Numpy中相当于Pytorch addmm_
或addmm
例如:
a = torch.tensor([[0.9619, 0.0384, 0.7012],
[0.5561, 0.3637, 0.9272]])
b = 2
c = a.sum(dim=1, keepdim=True).expand(b, b)
y = c + c.t()
y = y.addmm_(1, -2, a, a.t())
print(y)
# tensor([[0.5662, 1.1504], [1.1504, 1.0916]])