我是Pytorch的新手,遇到此错误:
x.gather(1,c)
RuntimeError:收集处的索引无效 /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:457
以下是有关张量的一些信息:
print(x.size())
print(c.size())
print(type(x))
print(type(c))
torch.Size([128, 2])
torch.Size([128, 1])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
x用浮点值填充,而c用整数填充,这可能是问题吗?
答案 0 :(得分:0)
这仅表示您的索引张量c
具有无效的索引。
例如,以下索引张量有效:
x = torch.tensor([
[5, 9, 1],
[3, 2, 8],
[7, 4, 0]
])
c = torch.tensor([
[0, 0, 0],
[1, 2, 0],
[2, 2, 1]
])
x.gather(1, c)
>>>tensor([[5, 5, 5],
[2, 8, 3],
[0, 0, 4]])
但是,以下索引张量无效:
c = torch.tensor([
[0, 0, 0],
[1, 2, 0],
[2, 2, 3]
])
它给出了您提到的例外情况
RuntimeError:聚集中的索引无效