我正在尝试计算网络的雅可比损失(即执行双重反向传播),并且出现以下错误: RuntimeError:梯度计算所需的变量之一已通过就地操作进行了修改
我在代码中找不到就地操作,所以我不知道要修复哪一行。
*错误发生在最后一行: loss3.backward()
inputs_reg = Variable(data, requires_grad=True)
output_reg = self.model.forward(inputs_reg)
num_classes = output.size()[1]
jacobian_list = []
grad_output = torch.zeros(*output_reg.size())
if inputs_reg.is_cuda:
grad_output = grad_output.cuda()
jacobian_list = jacobian.cuda()
for i in range(10):
zero_gradients(inputs_reg)
grad_output.zero_()
grad_output[:, i] = 1
jacobian_list.append(torch.autograd.grad(outputs=output_reg,
inputs=inputs_reg,
grad_outputs=grad_output,
only_inputs=True,
retain_graph=True,
create_graph=True)[0])
jacobian = torch.stack(jacobian_list, dim=0)
loss3 = jacobian.norm()
loss3.backward()
答案 0 :(得分:0)
public enum Week {
...
public Week getWeek(String key) {
... logic lookup
... obtain a Week(weekResult) with that key
return weekResult;
}
}
就位,grad_output.zero_()
也就位。就地意味着“修改张量而不是返回已应用修改的新张量”。并非就地的示例解决方案是torch.where
。将第一列清零的示例
grad_output[:, i-1] = 0
请注意import torch
t = torch.randn(3, 3)
ixs = torch.arange(3, dtype=torch.int64)
zeroed = torch.where(ixs[None, :] == 1, torch.tensor(0.), t)
zeroed
tensor([[-0.6616, 0.0000, 0.7329],
[ 0.8961, 0.0000, -0.1978],
[ 0.0798, 0.0000, -1.2041]])
t
tensor([[-0.6616, -1.6422, 0.7329],
[ 0.8961, -0.9623, -0.1978],
[ 0.0798, -0.7733, -1.2041]])
如何保留以前的值,而t
如何保留您想要的值。
答案 1 :(得分:0)
谢谢! 我将grad_output中有问题的代码替换为
inputs_reg = Variable(data, requires_grad=True)
output_reg = self.model.forward(inputs_reg)
num_classes = output.size()[1]
jacobian_list = []
grad_output = torch.zeros(*output_reg.size())
if inputs_reg.is_cuda:
grad_output = grad_output.cuda()
for i in range(5):
zero_gradients(inputs_reg)
grad_output_curr = grad_output.clone()
grad_output_curr[:, i] = 1
jacobian_list.append(torch.autograd.grad(outputs=output_reg,
inputs=inputs_reg,
grad_outputs=grad_output_curr,
only_inputs=True,
retain_graph=True,
create_graph=True)[0])
jacobian = torch.stack(jacobian_list, dim=0)
loss3 = jacobian.norm()
loss3.backward()
答案 2 :(得分:0)
您可以使用set_detect_anomaly
function available in autograd
包来确切地找到引起错误的行。
这里是link,描述了相同的问题以及使用上述功能的解决方案。