我是PyTorch的新手,在运行神经网络时遇到“索引张量必须具有与输入张量相同的维数”错误。发生在我调用torch.gather()的实例时。有人可以帮助我理解torch.gather()并解释此错误的原因吗?
以下是发生错误的代码:
def learn(batch, optim, net, target_net, gamma, global_step, target_update):
my_loss = []
optim.zero_grad()
state, action, next_state, reward, done, next_action = batch
qval = net(state.float())
loss_a = torch.gather(qval, 3, action.view(-1,1,1,1)).squeeze() #Error happens here!
loss_b = reward + gamma * torch.max(target_net(next_state.float()).cuda(), dim=3).values * (1 - done.int())
loss_val = torch.sum(( torch.abs(loss_a-loss_b) ))
loss_val /= 128
my_loss.append(loss_val.item())
loss_val.backward()
optim.step()
if global_step % target_update == 0:
target_network.load_state_dict(q_network.state_dict())
如果有帮助,这里是批处理功能,用于创建操作来源的批处理:
def sample_batch(memory,batch_size):
indices = np.random.randint(0,len(memory), (batch_size,))
state = torch.stack([memory[i][0] for i in indices])
action = torch.tensor([memory[i][1] for i in indices], dtype = torch.long)
next_state = torch.stack([memory[i][2] for i in indices])
reward = torch.tensor([memory[i][3] for i in indices], dtype = torch.float)
done = torch.tensor([memory[i][4] for i in indices], dtype = torch.float)
next_action = torch.tensor([memory[i][5] for i in indices], dtype = torch.long)
return state,action,next_state,reward,done,next_action
当我打印出不同形状的'qvals','action'和'action.view(-1,1,1,1)'时,这是输出:
qval torch.Size([10, 225])
act view torch.Size([10, 1, 1, 1])
action shape torch.Size([10])
感谢您对造成此错误的原因进行了任何解释!我想了解更多代码中的情况以及如何解决该问题。谢谢!
答案 0 :(得分:0)
Torch.gather被描述为here。如果我们采用您的代码,则此行
torch.gather(qval, 3, action.view(-1,1,1,1))
等同于
act_view = action.view(10,1,1,1)
out = torch.zeros_like(act_view)
for i in range(10):
for j in range(1):
for k in range(1):
for p in range(1):
out[i,j,k,p] = qval[i,j,k, act_view[i,j,k,p]]
return out
这显然使很少有意义。特别地,qval
不是4-D,因此无法像这样索引。 for
循环的数量由输入张量的形状决定,并且它们都应具有相同的维数才能起作用(这就是您的错误告诉您的)。在这里,qval
是2D,act_view
是4D。
我不确定您要怎么做,但是如果您可以解释您的目标并删除示例中所有无用的内容(主要是与培训和反向传播相关的代码),以获得最小的可重复示例,那么我可以帮助您进一步找到正确的方法:)