When I run the code below I get a RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. I have searched for ways to resolve this error, but none of them fixed my problem.
The code runs on PyTorch 0.4.
import torch
import torch.nn as nn
from torch.nn import functional as F
import pdb
class Summ(nn.Module):
    def __init__(self, in_dim=1024, hid_dim=256, num_layers=1, cell='lstm', action='wo', semantic=False):
        super(Summ, self).__init__()
        assert cell in ['lstm', 'gru'], "cell must be either 'lstm' or 'gru'"
        if cell == 'lstm':
            self.rnn = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=True, batch_first=True)
        else:
            self.rnn = nn.GRU(in_dim, hid_dim, num_layers=num_layers, bidirectional=True, batch_first=True)
        if semantic:
            self.lstm_video = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=False, batch_first=True)
            self.lstm_summary = nn.LSTM(in_dim, hid_dim, num_layers=num_layers, bidirectional=False, batch_first=True)
        self.fc = nn.Linear(hid_dim * 2, 1)
        self.fc1 = nn.Linear(in_dim, 1)
        self.fc2 = nn.Linear(in_dim, 1)
        self.fc3 = nn.Linear(in_dim, 1)
        self.action = action
        self.semantic = semantic
    def forward(self, x):
        # x: 1 x T x 1024
        if self.semantic:
            ori = x.clone()
            h_video, _ = self.lstm_video(x)  # embedding of video
        T = x.size()[1]
        D = x.size()[2]
        d1 = torch.zeros(T, D, requires_grad=True)
        d2 = torch.zeros(T, D, requires_grad=True)
        d3 = torch.zeros(T, D, requires_grad=True)
        # frame-difference features at temporal strides 1, 2 and 4
        for i in range(T - 1):
            d1[i, :] = abs(x[0, i + 1, :] - x[0, i, :])
        for i in range(T - 2):
            d2[i, :] = abs(x[0, i + 2, :] - x[0, i, :])
        for i in range(T - 4):
            d3[i, :] = abs(x[0, i + 4, :] - x[0, i, :])
        f1 = self.fc1(d1.cuda())
        f2 = self.fc2(d2.cuda())
        f3 = self.fc3(d3.cuda())
        if self.action in ['softmax', 'softmax_skip']:
            f = F.softmax(f1 + f2 + f3, dim=0)
        else:
            f = f1 + f2 + f3
        # re-weight every frame of x by its score (in-place assignment into x)
        for i in range(T):
            with torch.no_grad():
                if self.action == 'softmax_skip':
                    x[0, i, :] = x[0, i, :].clone() * (1 + f[i, 0])
                else:
                    x[0, i, :] = x[0, i, :].clone() * f[i, 0]
        h, _ = self.rnn(x)
        p = torch.sigmoid(self.fc(h))
        if self.semantic:
            # re-weight the original frames by the predicted scores (in-place assignment into ori)
            for i in range(T):
                with torch.no_grad():
                    ori[0, i, :] = ori[0, i, :].clone() * p[0, i, 0]
            h_summary, _ = self.lstm_summary(ori)  # embedding of summary
            return p, h_video, h_summary
        return p
def semantic_loss(hv, hs):
    h_vd = torch.squeeze(hv)
    h_sm = torch.squeeze(hs)
    gap_h = h_vd - h_sm
    loss = torch.norm(gap_h)
    return loss
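My guess (I am not sure this is right) is that the error comes from the in-place writes in forward(), i.e. the x[0, i, :] = ... and ori[0, i, :] = ... assignments, which overwrite tensors that autograd still needs for the backward pass. Below is a minimal sketch, independent of my model, that as far as I understand triggers the same message; the variable names are made up just for illustration:

import torch

a = torch.ones(3, requires_grad=True)
b = a * 2            # intermediate result
c = b ** 2           # backward of ** needs the original values of b
b[0] = 5.0           # in-place write, analogous to x[0, i, :] = ... in forward()
c.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation

Is the right direction to build the re-weighted sequence out of place instead, e.g. collecting the scaled frames in a list and stacking them into a new tensor with torch.stack, rather than assigning into x and ori? Or is the with torch.no_grad() around those assignments the actual problem?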