指定网络以不反向传播给定函数中的渐变

时间:2019-09-03 15:09:42

标签: python pytorch

在我的神经网络中,为了计算损失,我需要在训练过程中进行一些中间计算,以便首先获得一些变换rv

rv = factor.ransac(source, target, prob, device)
predicted = factor.predict(source, rv, outputs, device)
loss = criterion(predicted, target)

我只想通过predicted而不是通过factor.ransac反向传播渐变。我该怎么做?

1 个答案:

答案 0 :(得分:0)

您可以使用detach()

rv = factor.ransac(source, target, prob, device).detach()

这样,rv上的渐变将被丢弃。