Pytorch 损失不为零但梯度为零

时间:2021-03-31 02:25:19

标签: pytorch

我想得到网络中每个子模块的损失,所以每个子模块输出它的损失值。

下面是我的代码。

import torch
import torch.nn as nn
from torch.optim import SGD

class m1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,1,3)
        self.conv2 = nn.Conv2d(1,1,3)
        self.loss = nn.SmoothL1Loss()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        
        return dict(m1loss=self.loss(x, x + 1.1)), x


class m2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,1,3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(1,1,3)
        self.loss = nn.SmoothL1Loss()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return dict(m2loss=self.loss(x, x + 2.31)), x
    
    
class m(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,1,3)
        self.conv2 = nn.Conv2d(1,1,1)
        self.relu = nn.ReLU()
        self.m1 = m1()
        self.m2 = m2()
        self.losses = dict()
        self.loss = nn.SmoothL1Loss()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        loss, x = self.m1(x)
        self.losses.update(loss)
        loss, x = self.m2(x)
        self.losses.update(loss)
        self.losses.update(mloss=self.loss(x, x + 1.21))
        return self.losses
    
net = m()
sgd = SGD(net.parameters(), lr=10)

x = torch.rand((1,1,24,24))
out = net(x)
sgd.zero_grad()
l = sum(v for k,v in out.items())
print('Loss: ', l)
l.backward()
sgd.step()

for p in net.parameters():
    print('grad: ', p.grad)

我检查了网络的权重是否更新,但根本没有更新。

当我打印损失值和所有梯度时, 损失值=3.12,但所有梯度都为零。

怎么了?

Loss:  tensor(3.1200, grad_fn=<AddBackward0>)
grad:  tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
grad:  tensor([0.])
grad:  tensor([[[[0.]]]])
grad:  tensor([0.])
grad:  tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
grad:  tensor([0.])
grad:  tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
grad:  tensor([0.])
grad:  tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
grad:  tensor([0.])
grad:  tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])
grad:  tensor([0.])

0 个答案:

没有答案
相关问题