为了平衡向后计算损失,我不仅要在forward()中计算损失,而且要向后实现损失。这是我的实现方式,有什么不对吗?
def forward(self, x, y=None, criterion=None, gpu_nums=1):
x = self.task(x)
if not y == None:
loss = criterion(x, y)
if gpu_nums > 1:
loss /= gpu_nums
loss.backward()
return x, loss
return x
我的实现是否正确?如果不正确,希望您能提供正确的方法。这对每个人都有意义。