我有一个PyTorch计算图,它包含一个执行某些计算的子图,然后将这个计算结果(让我们称之为x
)分支到另外两个子图中。这两个子图中的每一个都会产生一些标量结果(让我们称之为y1
和y2
)。我想为这两个结果中的每一个做一个向后传递(也就是说,我想积累两个子图的渐变。我不想执行实际的优化步骤。)
现在,由于内存是一个问题,我想按以下顺序执行操作:
首先,计算x
。然后,计算y1
,并执行y1.backward()
while(这是关键点)保留导致x
的图表,但从x
释放图表到y1
。然后,计算y2
,然后执行y2.backward()
。
换句话说,为了在不牺牲太多速度的情况下节省内存,我想保留x
而不需要重新计算它,但我想放弃从x
到{{1}的所有计算在我不再需要它们之后。
问题是函数y1
的参数retain_graph
将保留导致backward()
的整个图,而我只需要保留图的一部分导致y1
1}}。
这是我理想的想法的例子:
x
如何做到这一点?
答案 0 :(得分:2)
参数import torch
w = torch.tensor(1.0)
w.requires_grad_(True)
# sub-graph for calculating `x`
x = w+10
# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.
# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
将保留整个图形,而不仅仅是子图形。但是,我们可以使用垃圾收集来释放图中不需要的部分。通过删除从retain_graph
到x
的子图的所有引用,此子图将被释放:
y1