How to find the in-place operation: one of the variables needed for gradient computation has been modified by an inplace operation

Date: 2019-07-16 08:57:56

Tags: python, pytorch

When I run the following code, I get RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. I have looked for ways to fix this error, but none of them solved my problem. How can I find which in-place operation is causing it?

This code runs on PyTorch 0.4.
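
For context, autograd raises this error when a tensor that was saved during the forward pass is overwritten in place before backward() runs: every in-place write bumps the tensor's version counter, and backward() checks that the counter still matches. A minimal repro, with throwaway variable names that are not taken from the code below:

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2
c = b * b            # autograd saves b to compute c's backward
b += 1               # in-place write: bumps b's version counter
c.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation

In more recent PyTorch releases, running the forward and backward passes inside with torch.autograd.detect_anomaly(): makes the traceback point at the forward operation whose saved tensor was clobbered, which is usually the fastest way to locate the offending write.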

import torch
import torch.nn as nn
from torch.nn import functional as F
import pdb

class Summ(nn.Module):
    def __init__(self, in_dim=1024, hid_dim=256, num_layers=1, cell='lstm', action='wo', semantic=False):
        super(Summ, self).__init__()
        assert cell in ['lstm', 'gru'], "cell must be either 'lstm' or 'gru'"
        if cell == 'lstm':
            self.rnn = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=True, batch_first=True)
        else:
            self.rnn = nn.GRU(in_dim, hid_dim, num_layers=num_layers, bidirectional=True, batch_first=True)
        if semantic:
            self.lstm_video = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=False, batch_first=True)
            self.lstm_summary = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=False, batch_first=True)

        self.fc = nn.Linear(hid_dim*2, 1)
        self.fc1 = nn.Linear(in_dim, 1)
        self.fc2 = nn.Linear(in_dim, 1)
        self.fc3 = nn.Linear(in_dim, 1)
        self.action = action
        self.semantic = semantic

    def forward(self, x):
        # x: 1 x T x 1024 (batch x frames x feature dim)
        if self.semantic:
            ori = x.clone()
            h_video, _ = self.lstm_video(x) # embedding of video

        T = x.size()[1]
        D = x.size()[2]
        # NOTE: d1/d2/d3 are leaf tensors with requires_grad=True; the index
        # assignments in the loops below are themselves in-place writes that
        # autograd tracks
        d1 = torch.zeros(T, D, requires_grad=True)
        d2 = torch.zeros(T, D, requires_grad=True)
        d3 = torch.zeros(T, D, requires_grad=True)
        for i in range(T-1):
            d1[i,:] = abs(x[0,i+1,:] - x[0,i,:])

        for i in range(T-2):
            d2[i,:] = abs(x[0,i+2,:] - x[0,i,:])

        for i in range(T-4):
            d3[i,:] = abs(x[0,i+4,:] - x[0,i,:])

        f1 = self.fc1(d1.cuda())
        f2 = self.fc2(d2.cuda())
        f3 = self.fc3(d3.cuda())

        if self.action in ['softmax', 'softmax_skip']:
            f = F.softmax(f1 + f2 + f3, dim=0)
        else:
            f = f1 + f2 + f3

        for i in range(T):
            with torch.no_grad():
                # in-place write into x: even under no_grad() this bumps x's
                # version counter, so a later backward through an op that saved
                # x (e.g. self.lstm_video above) raises the RuntimeError; the
                # no_grad() block also means f receives no gradient here
                if self.action == 'softmax_skip':
                    x[0,i,:] = x[0,i,:].clone() * (1 + f[i,0])
                else:
                    x[0,i,:] = x[0,i,:].clone() * f[i,0]
        h, _ = self.rnn(x)
        p = torch.sigmoid(self.fc(h))

        if self.semantic:
            for i in range(T):
                with torch.no_grad():
                    # same in-place pattern on ori; no_grad() also cuts the
                    # gradient from p into the semantic loss
                    ori[0,i,:] = ori[0,i,:].clone() * p[0,i,0]
            h_summary, _ = self.lstm_summary(ori) # embedding of summary

            return p, h_video, h_summary

        return p

def semantic_loss(hv, hs):
    h_vd = torch.squeeze(hv)
    h_sm = torch.squeeze(hs)

    gap_h = h_vd - h_sm
    loss = torch.norm(gap_h)
    return loss
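
Not a confirmed fix, but a sketch of how the two in-place patterns above could be rewritten out of place; frame_differences is a hypothetical helper, and the shapes follow the original code (x is 1 x T x D, f is T x 1, p is 1 x T x 1):

import torch
from torch.nn import functional as F

def frame_differences(x, step):
    # |x[0, i+step, :] - x[0, i, :]| for every valid i, computed by slicing
    # instead of writing into a preallocated requires_grad leaf tensor
    diff = (x[0, step:, :] - x[0, :-step, :]).abs()  # (T - step) x D
    return F.pad(diff, (0, 0, 0, step))              # zero rows restore T x D

# inside forward, the three loops would become
#     d1 = frame_differences(x, 1)
#     d2 = frame_differences(x, 2)
#     d3 = frame_differences(x, 4)
# and the gating loop would become a broadcast multiply that produces a new
# tensor instead of mutating x (drop the torch.no_grad() block if the gate is
# supposed to receive gradients):
#     x = x * f.unsqueeze(0)            # or x * (1 + f.unsqueeze(0))
# the same applies to ori:
#     ori = ori * p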
