pytorch RuntimeError:预期类型为torch.cuda.LongTensor的对象,但发现类型为torch.LongTensor的参数

时间:2019-01-28 10:55:49

标签: pytorch

我正在实施dqn进行强化学习,但出现以下错误:

  

proj_dist.view(-1).index_add_(0,(l + offset).view(-1),(next_dist *   (u.float()-b))。view(-1))

     

RuntimeError:类型为torch.cuda.LongTensor的预期对象,但找到了   输入torch.LongTensor作为参数#3'other'

我已经尝试将对index_add_的第一次调用的第三个参数修改为:

  • .long().cuda()
  • long().to(device) ...

没有任何作用

这是代码,其中函数的参数均为张量。请注意,我使用的是64个张量,因此next_states将是一些64张量:

def projection_distribution(self, next_states, rewards, dones):
        batch_size  = next_states.size(0)

        delta_z = float(self.V_max - self.V_min) / (self.num_atoms - 1)
        support = torch.linspace(self.V_min, self.V_max, self.num_atoms)

        # torch.Size([64, 4, 51])
        next_dist   = self.qnetwork_target(next_states).data.cpu() * 
                       support.data.cpu()
        # torch.Size([64])
        next_actions = next_dist.sum(2).max(1)[1]
        # torch.Size([64, 1, 51])
        next_actions = 
        next_actions.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, 
        next_dist.size(2))
        # these are the batch_size distributions relative to the optimal 
        actions 'next_actions'
        # torch.Size([64, 51])
        next_dist   = next_dist.gather(1, next_actions).squeeze(1)

        rewards = rewards.expand_as(next_dist)
        print('rewards: ', rewards.size())
        dones   = dones.expand_as(next_dist)
        support = support.unsqueeze(0).expand_as(next_dist)
        print('support: ', support.size())

        Tz = rewards + (1 - dones) * GAMMA * support.to(device)
        Tz = Tz.clamp(min=self.V_min, max=self.V_max)
        b  = (Tz - self.V_min) / delta_z
        l  = b.floor().long()
        u  = b.ceil().long()
        print('TYPE: ' , b.type(), l.type(), u.type(), next_dist.type())

        offset = torch.linspace(0, (batch_size - 1) * self.num_atoms, batch_size).long()\
                        .unsqueeze(1).expand(batch_size, self.num_atoms)
        print('offset: ', offset.size())

        proj_dist = torch.zeros(next_dist.size())  
        print('proj_dist_1: ', proj_dist.size())
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
        print('proj_dist_2: ', proj_dist.size())
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
        print('proj_dist_3: ', proj_dist.size())

        return proj_dist

为什么会出现该错误?我如何解决它?

非常感谢 最高

0 个答案:

没有答案