PyTorch:当使用backward()时,我如何只保留图的一部分?

时间:2018-06-07 12:33:14

标签: pytorch

我有一个PyTorch计算图,它包含一个执行某些计算的子图,然后将这个计算结果(让我们称之为x)分支到另外两个子图中。这两个子图中的每一个都会产生一些标量结果(让我们称之为y1y2)。我想为这两个结果中的每一个做一个向后传递(也就是说,我想积累两个子图的渐变。我不想执行实际的优化步骤。)

现在,由于内存是一个问题,我想按以下顺序执行操作: 首先,计算x。然后,计算y1,并执行y1.backward() while(这是关键点)保留导致x的图表,但从x释放图表到y1 。然后,计算y2,然后执行y2.backward()

换句话说,为了在不牺牲太多速度的情况下节省内存,我想保留x而不需要重新计算它,但我想放弃从x到{{1}的所有计算在我不再需要它们之后。

问题是函数y1的参数retain_graph将保留导致backward()的整个图,而我只需要保留图的一部分导致y1 1}}。

这是我理想的想法的例子:

x

如何做到这一点?

1 个答案:

答案 0 :(得分:2)

参数import torch w = torch.tensor(1.0) w.requires_grad_(True) # sub-graph for calculating `x` x = w+10 # sub-graph for calculating `y1` x1 = x*x y1 = x1*x1 y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it. # sub-graph for calculating `y2` x2 = torch.sqrt(x) y2 = x2/2 y2.backward() 将保留整个图形,而不仅仅是子图形。但是,我们可以使用垃圾收集来释放图中不需要的部分。通过删除从retain_graphx的子图的所有引用,此子图将被释放:

y1