我是使用Pytorch的新手,我在运行代码时收到此错误:
TypeError:使用类型为torch.LongTensor的对象索引张量。唯一支持的类型是整数,切片,numpy标量和torch.LongTensor或torch.ByteTensor作为唯一的参数。
您能否指出我正确的方向和任何帮助将不胜感激。
if os.path.exists(CHECKPOINT_NAME):
print("=> loading checkpoint '{}'".format(CHECKPOINT_NAME))
checkpoint = torch.load(CHECKPOINT_NAME)
EPOCH = checkpoint['epoch']
BEST_LOSS = checkpoint['best_loss']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(CHECKPOINT_NAME, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'. Starting from scratch".format(CHECKPOINT_NAME))
for epoch in range(EPOCH, NUM_EPOCHS):
train(train_dataset_loader, model, loss_fn, optimizer, epoch + 1, val_dataset_loader)
loss = validate(val_dataset_loader, model, loss_fn)
if loss < BEST_LOSS:
print('{} better than previous best loss of {}'.format(loss, BEST_LOSS))
BEST_LOSS = loss
is_best = True
else:
is_best = False
save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'best_loss': BEST_LOSS,
'optimizer' : optimizer.state_dict(),
}, is_best
)
ypeError Traceback (most recent call last)
<ipython-input-16-4c3a0a33f81b> in <module>()
12
13 for epoch in range(EPOCH, NUM_EPOCHS):
---> 14 train(train_dataset_loader, model, loss_fn, optimizer, epoch + 1, val_dataset_loader)
15 loss = validate(val_dataset_loader, model, loss_fn)
16
<ipython-input-14-13120db09b49> in train(train_loader, model, criterion, optimizer, epoch, val_loader)
65 # compute output
66 model.zero_grad()
---> 67 log_probas, indices = model.forward(batch)
68
69 labels = Variable(batch['class'][indices])
<ipython-input-13-f9a47d332f53> in forward(self, batch)
18 gene = batch['gene'][indices]
19 variation = batch['variation'][indices]
---> 20 text_batch = torch.stack(batch['text'], 0)[:, indices]
21
22 # Wrap all tensors around a variable. Send to GPU if possible.
&#13;
答案 0 :(得分:0)
您的问题仍然存在问题。您尚未共享可以重现错误的完整代码。从错误中可以清楚地看出您的模型的前向功能存在问题。错误发生在下面一行。
library(dplyr)
ggplot(Monthly_BMS_df, aes(time, value, group=brand, colour=brand)) +
geom_hline(yintercept=0, colour="grey60") +
geom_text(data=Monthly_BMS_df %>% filter(time==min(time)),
aes(label=brand), position=position_nudge(-0.25)) +
geom_line(linetype="12", alpha=0.5, size=0.7) +
geom_text(aes(label=value)) +
guides(colour=FALSE) +
theme_classic()
当您为没有预期形状的张量索引时会发生错误。因此,使用text_batch = torch.stack(batch['text'], 0)[:, indices]
检查张量形状的形状。如果您需要更多帮助,请按照以下指南改进您的问题。
请阅读并遵循帮助文档中的发布指南。 Minimal, complete, verifiable示例适用于此处。在您发布MCVE代码并准确描述问题之前,我们无法有效地为您提供帮助。我们应该能够将您发布的代码粘贴到文本文件中,并重现您描述的问题。