我有一个非常简单的问题。
假设我有两个网络要训练(即 net1、net2)。 net1 的输出将在训练时输入 net2。 就我而言,我只想更新 net1:
optimizer=Optimizer(net1.parameters(), **kwargs)
loss=net2(net1(x))
loss.backward()
optimizer.step()
虽然这将实现我的目标,但它占用了过多的冗余内存,因为这将计算 net2 的梯度(导致 OOM 错误)。 因此,我尝试了多种尝试来解决此问题:
z=net1(x)
with torch.no_grad():
loss=net2(z)
没有提高 OOM,但删除了所有梯度,包括来自 net1 的梯度。
net2.requires_grad=False
loss=net2(net1(x))
提高了 OOM。
z=net1(x)
loss=net2(z).detach()
没有提高 OOM,但删除了所有梯度,包括来自 net1 的梯度。
net2.eval()
loss=net2(net1(x))
提高了 OOM。
有没有什么方法可以只计算前端网络(net1)的梯度以提高内存效率? 任何建议将不胜感激。
答案 0 :(得分:1)
首先让我们试着理解为什么你的方法不起作用。
net1
需要渐变,因此忽略后续 requires_grad=False
。根据您的架构,OOM 错误可能来自保存计算图中的所有中间值(通常是 CNN 中的一个问题),也可能来自必须存储梯度(在全连接网络中更常见)。
您可能正在寻找所谓的“检查点”,您甚至不必自己实现它,您可以使用 pytorch 的检查点 API,查看 documentation。
这基本上可以让您分别计算和处理 net1
和 net2
的梯度。请注意,您确实需要所有梯度信息通过 net2
,否则您无法计算梯度wrt。 net1
!