索引张量必须具有与使用torch.gather()时遇到的输入张量错误相同的维数

时间:2020-08-19 18:45:50

标签: python pytorch

我是PyTorch的新手,在运行神经网络时遇到“索引张量必须具有与输入张量相同的维数”错误。发生在我调用torch.gather()的实例时。有人可以帮助我理解torch.gather()并解释此错误的原因吗?

以下是发生错误的代码:

  def learn(batch, optim, net, target_net, gamma, global_step, target_update):
      my_loss = []
      optim.zero_grad()
  
      state, action, next_state, reward, done, next_action = batch
      qval = net(state.float())
  
      loss_a = torch.gather(qval, 3, action.view(-1,1,1,1)).squeeze() #Error happens here!

      loss_b = reward + gamma * torch.max(target_net(next_state.float()).cuda(), dim=3).values * (1 - done.int())
      loss_val = torch.sum(( torch.abs(loss_a-loss_b) ))
      loss_val /= 128
      my_loss.append(loss_val.item())
      loss_val.backward()
      optim.step()
      if global_step % target_update == 0:
          target_network.load_state_dict(q_network.state_dict())

如果有帮助,这里是批处理功能,用于创建操作来源的批处理:

def sample_batch(memory,batch_size):
    
    indices = np.random.randint(0,len(memory), (batch_size,))

    state = torch.stack([memory[i][0] for i in indices]) 
    action = torch.tensor([memory[i][1] for i in indices], dtype = torch.long)
    next_state = torch.stack([memory[i][2] for i in indices])
    reward = torch.tensor([memory[i][3] for i in indices], dtype = torch.float)
    done = torch.tensor([memory[i][4] for i in indices], dtype = torch.float)
    next_action = torch.tensor([memory[i][5] for i in indices], dtype = torch.long)

    return state,action,next_state,reward,done,next_action

当我打印出不同形状的'qvals','action'和'action.view(-1,1,1,1)'时,这是输出:

qval torch.Size([10, 225])
act view torch.Size([10, 1, 1, 1])
action shape  torch.Size([10])

感谢您对造成此错误的原因进行了任何解释!我想了解更多代码中的情况以及如何解决该问题。谢谢!

1 个答案:

答案 0 :(得分:0)

Torch.gather被描述为here。如果我们采用您的代码,则此行

torch.gather(qval, 3, action.view(-1,1,1,1))

等同于

act_view = action.view(10,1,1,1)
out = torch.zeros_like(act_view)
for i in range(10):
    for j in range(1):
         for k in range(1):
              for p in range(1):
                   out[i,j,k,p] = qval[i,j,k, act_view[i,j,k,p]]
return out

这显然使很少有意义。特别地,qval不是4-D,因此无法像这样索引。 for循环的数量由输入张量的形状决定,并且它们都应具有相同的维数才能起作用(这就是您的错误告诉您的)。在这里,qval是2D,act_view是4D。

我不确定您要怎么做,但是如果您可以解释您的目标并删除示例中所有无用的内容(主要是与培训和反向传播相关的代码),以获得最小的可重复示例,那么我可以帮助您进一步找到正确的方法:)