我在 Google Colab 中遇到了这个错误。我尝试过其他数据类型（例如布尔张量），但都不起作用，请帮忙。
代码
def _mask(prev_generated_seq):
prev_mask = torch.eq(prev_generated_seq, 1)
lengths = torch.argmax(prev_mask,dim=1)
#test = torch.max(prev_mask,dim=1)
#lengths = torch.FloatTensor(test)
max_len = prev_generated_seq.size(1)
mask = []
for i in range(prev_generated_seq.size(0)):
if lengths[i] == 0:
mask_line = [0] * max_len
else:
mask_line = [0] * lengths[i].item()
mask_line.extend([1] * (max_len - lengths[i].item()))
mask.append(mask_line)
mask = torch.ByteTensor(mask)
if args.cuda:
mask = mask.cuda()
return prev_generated_seq.data.masked_fill_(mask, 0)
错误
File "main.py", line 179, in <module>
train_epoches(abstracts, model, config.epochs, teacher_forcing_ratio=1)
File "main.py", line 155, in train_epoches
target_variables, model, teacher_forcing_ratio)
File "main.py", line 139, in train_batch
prev_generated_seq = _mask(prev_generated_seq)
File "main.py", line 101, in _mask
lengths = torch.argmax(prev_mask,dim=1)
RuntimeError: "argmax_cuda" not implemented for 'Bool'