pytorch梯度尚未计算

时间:2019-08-19 06:55:21

标签: python pytorch

我创建一个NN。我在重新计算渐变时遇到问题。问题是我将2个张量u @ v标量相乘并将它们之一标准化。重要的是不能为h计算梯度。因此,我使用detach()。此外,在重新计算梯度期间,不应考虑归一化(我不知道该怎么做)。

import torch
from torch import nn


class Nn(nn.Module):
    def __init__(self):
        super(Nn, self).__init__()
        self.ln = nn.Linear(5, 5)

    def forward(self, x):
        v = self.ln(x)

        u = v.clone()
        h = v.clone()

        u /= u.norm()
        h = h.detach()
        h /= h.norm()

        res = torch.stack([torch.stack([u @ h, u @ h])])

        return res


def patches_generator():
    while True:
        decoder = torch.rand((5, ))
        target = torch.randint(2, (1,))
        yield decoder, target


net = Nn()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

net.train()
torch.autograd.set_detect_anomaly(True)
for decoder, targets in patches_generator():
    optimizer.zero_grad()
    outputs = net(decoder)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

结果,我收到以下错误:

  

RuntimeError:梯度计算所需的变量之一具有   通过就地操作进行了修改:[torch.FloatTensor [9,512,1,   1]],其版本为ReluBackward1的输出0,版本3;预期   版本2代替。提示:上方的回溯显示   无法计算其梯度的操作。中的变量   问题在之后或之后的任何地方已更改。祝你好运!

1 个答案:

答案 0 :(得分:1)

问题是此行中的就地除法运算符应用于u

u /= u.norm()

将其更改为

u = u / u.norm()

使代码运行。原因是就地运算符会覆盖此行的中间结果

u = v.clone()

这使得Pytorch无法计算梯度。

(问题中的错误消息包含对ReluBackward1层的引用,不在简化代码示例中。Pytorch ReLU层具有可选的in_place自变量,该变量使操作这通常是可行的,因为在顺序网络中,无需区分ReLU激活的输出和权重的输出以计算梯度,但是在更复杂的体系结构中,可能有必要保留输出权重。)