Pytorch keep_graph =是真的张量还是要从函数图中分离张量?

时间:2019-03-24 16:10:27

标签: python gradient pytorch

所以我在Pytorch中使用函数图遇到了很多麻烦,我不确定如何解决它。

错误: RuntimeError:试图第二次遍历该图,但是缓冲区已被释放。第一次回叫时,请指定keep_graph = True。

我开始理解,发生这种情况的原因是我的张量没有变化,required_grad=True并且当我再次尝试通过它进行反向传播时,功能图从最后的反向传播仍然存在(它是否正确?)。

据我所知,我对此问题有两种解决方案:要么启用retain_graph = True,以使函数图保持不变,要么我必须将张量与函数图分离,然后稍后再附加它们。我不确定该如何解决。任何帮助表示赞赏!

    for index in indeces:
        if i >= 6:
            return
        z = self.trajectory.outcome_values[index]
        v = self.trajectory.net_outcome_values[index]
        prob = self.trajectory.mc_probs[index]

        net_prob = self.trajectory.net_probs[index]
        log_prob = torch.log(net_prob)
        lossfn = ((z-v).pow(2) - prob*log_prob).mean() 
        lossfn.backward()
        self.optimizer.step()
        i +=1

它应该学习6个小批处理,其中的索引来自BatchSampler(SubsetRandomSampler)。这对于第一次更新非常有效,但是在第一次更新之后,循环中的下一步会产生上述错误。

“我的弹道”是一门课程,我在其中保存在仿真过程中获得的所有信息。张量 net_prob net_outcome_values 具有required_grad=True,而 outcome_values mc_probs 具有required_grad=False。我不确定是否必须detach()的净值才能使其正常工作,还是应该保留函数图。

0 个答案:

没有答案