如何获得特定维度上的张量的MSE?

时间:2020-04-01 17:24:45

标签: python pytorch tensor mean-square-error

我有2个张量,其中.size中的torch.Size([2272, 161])。我想得到它们之间的均方误差。但是,我希望沿着161个通道中的每个通道使用,因此我的错误张量的.sizetorch.Size([161])。我该怎么做?

似乎torch.nn.MSELoss不允许我指定尺寸。

1 个答案:

答案 0 :(得分:3)

对于nn.MSELoss,您可以指定选项reduction='none'。然后,这将为您返回两个张量的每个入口位置的平方误差。然后,您可以应用torch.sum / torch.mean。

a = torch.randn(2272,161)
b = torch.randn(2272,161)
loss = nn.MSELoss(reduction='none')
loss_result = torch.sum(loss(a,b),dim=0) 

我认为没有直接的方法可以在损失初始化时指定将均值/总和应用于哪个维度。希望有帮助!