在 Pytorch 中只计算前端网络的梯度

时间:2021-04-12 10:01:39

标签: pytorch

我有一个非常简单的问题。

假设我有两个网络要训练(即 net1、net2)。 net1 的输出将在训练时输入 net2。 就我而言,我只想更新 net1:

optimizer=Optimizer(net1.parameters(), **kwargs)
loss=net2(net1(x))
loss.backward()
optimizer.step()

虽然这将实现我的目标,但它占用了过多的冗余内存,因为这将计算 net2 的梯度(导致 OOM 错误)。 因此,我尝试了多种尝试来解决此问题:

  1. torch.no_grad:
z=net1(x)
with torch.no_grad():
    loss=net2(z)

没有提高 OOM,但删除了所有梯度,包括来自 net1 的梯度。

  1. requires_grad=False:
net2.requires_grad=False
loss=net2(net1(x))

提高了 OOM。

  1. 分离():
z=net1(x)
loss=net2(z).detach()

没有提高 OOM,但删除了所有梯度,包括来自 net1 的梯度。

  1. eval():
net2.eval()
loss=net2(net1(x))

提高了 OOM。

有没有什么方法可以只计算前端网络(net1)的梯度以提高内存效率? 任何建议将不胜感激。

1 个答案:

答案 0 :(得分:1)

首先让我们试着理解为什么你的方法不起作用。

  1. 此上下文管理器禁用所有梯度计算。
  2. 由于 net1 需要渐变,因此忽略后续 requires_grad=False
  3. 如果你在那个状态分离,这意味着梯度计算已经停止了
  4. Eval 只是将 net2 设置为 eval 模式,根本不影响梯度计算。

根据您的架构,OOM 错误可能来自保存计算图中的所有中间值(通常是 CNN 中的一个问题),也可能来自必须存储梯度(在全连接网络中更常见)。

您可能正在寻找所谓的“检查点”,您甚至不必自己实现它,您可以使用 pytorch 的检查点 API,查看 documentation

这基本上可以让您分别计算和处理 net1net2 的梯度。请注意,您确实需要所有梯度信息通过 net2,否则您无法计算梯度wrt。 net1