使用类型为torch.LongTensor的对象索引张量

时间:2017-09-21 01:26:52

标签: python pytorch

我是使用Pytorch的新手,我在运行代码时收到此错误:

TypeError:使用类型为torch.LongTensor的对象索引张量。唯一支持的类型是整数,切片,numpy标量和torch.LongTensor或torch.ByteTensor作为唯一的参数。

您能否指出我正确的方向和任何帮助将不胜感激。

if os.path.exists(CHECKPOINT_NAME):
print("=> loading checkpoint '{}'".format(CHECKPOINT_NAME))
checkpoint = torch.load(CHECKPOINT_NAME)
EPOCH = checkpoint['epoch']
BEST_LOSS = checkpoint['best_loss']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
      .format(CHECKPOINT_NAME, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'. Starting from scratch".format(CHECKPOINT_NAME))

for epoch in range(EPOCH, NUM_EPOCHS):
    train(train_dataset_loader, model, loss_fn, optimizer, epoch + 1, val_dataset_loader)
    loss = validate(val_dataset_loader, model, loss_fn)

    if loss < BEST_LOSS:
        print('{} better than previous best loss of {}'.format(loss, BEST_LOSS))
        BEST_LOSS = loss
        is_best = True
    else:
        is_best = False

    save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': BEST_LOSS,
            'optimizer' : optimizer.state_dict(),
        }, is_best
    )

&#13;
&#13;
ypeError                                 Traceback (most recent call last)
<ipython-input-16-4c3a0a33f81b> in <module>()
     12 
     13 for epoch in range(EPOCH, NUM_EPOCHS):
---> 14     train(train_dataset_loader, model, loss_fn, optimizer, epoch + 1, val_dataset_loader)
     15     loss = validate(val_dataset_loader, model, loss_fn)
     16 

<ipython-input-14-13120db09b49> in train(train_loader, model, criterion, optimizer, epoch, val_loader)
     65         # compute output
     66         model.zero_grad()
---> 67         log_probas, indices = model.forward(batch)
     68 
     69         labels = Variable(batch['class'][indices])

<ipython-input-13-f9a47d332f53> in forward(self, batch)
     18         gene = batch['gene'][indices]
     19         variation = batch['variation'][indices]
---> 20         text_batch = torch.stack(batch['text'], 0)[:, indices]
     21 
     22         # Wrap all tensors around a variable. Send to GPU if possible.
&#13;
&#13;
&#13;

1 个答案:

答案 0 :(得分:0)

您的问题仍然存在问题。您尚未共享可以重现错误的完整代码。从错误中可以清楚地看出您的模型的前向功能存在问题。错误发生在下面一行。

library(dplyr)

ggplot(Monthly_BMS_df, aes(time, value, group=brand, colour=brand)) + 
  geom_hline(yintercept=0, colour="grey60") +
  geom_text(data=Monthly_BMS_df %>% filter(time==min(time)),
            aes(label=brand), position=position_nudge(-0.25)) +
  geom_line(linetype="12", alpha=0.5, size=0.7) +
  geom_text(aes(label=value)) +
  guides(colour=FALSE) +
  theme_classic()

当您为没有预期形状的张量索引时会发生错误。因此,使用text_batch = torch.stack(batch['text'], 0)[:, indices] 检查张量形状的形状。如果您需要更多帮助,请按照以下指南改进您的问题。

请阅读并遵循帮助文档中的发布指南。 Minimal, complete, verifiable示例适用于此处。在您发布MCVE代码并准确描述问题之前,我们无法有效地为您提供帮助。我们应该能够将您发布的代码粘贴到文本文件中,并重现您描述的问题。