我想得到网络中每个子模块的损失,所以每个子模块输出它的损失值。
下面是我的代码。
import torch
import torch.nn as nn
from torch.optim import SGD
class m1(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1,1,3)
self.conv2 = nn.Conv2d(1,1,3)
self.loss = nn.SmoothL1Loss()
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return dict(m1loss=self.loss(x, x + 1.1)), x
class m2(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1,1,3)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(1,1,3)
self.loss = nn.SmoothL1Loss()
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
return dict(m2loss=self.loss(x, x + 2.31)), x
class m(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1,1,3)
self.conv2 = nn.Conv2d(1,1,1)
self.relu = nn.ReLU()
self.m1 = m1()
self.m2 = m2()
self.losses = dict()
self.loss = nn.SmoothL1Loss()
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
loss, x = self.m1(x)
self.losses.update(loss)
loss, x = self.m2(x)
self.losses.update(loss)
self.losses.update(mloss=self.loss(x, x + 1.21))
return self.losses
net = m()
sgd = SGD(net.parameters(), lr=10)
x = torch.rand((1,1,24,24))
out = net(x)
sgd.zero_grad()
l = sum(v for k,v in out.items())
print('Loss: ', l)
l.backward()
sgd.step()
for p in net.parameters():
print('grad: ', p.grad)
我检查了网络的权重是否更新,但根本没有更新。
当我打印损失值和所有梯度时, 损失值=3.12,但所有梯度都为零。
怎么了?
Loss: tensor(3.1200, grad_fn=<AddBackward0>)
grad: tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
grad: tensor([0.])
grad: tensor([[[[0.]]]])
grad: tensor([0.])
grad: tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
grad: tensor([0.])
grad: tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
grad: tensor([0.])
grad: tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
grad: tensor([0.])
grad: tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
grad: tensor([0.])