我在火炬中有一个张量 x ,假设形状为(5,3,2,6),而另一个张量 idx 形状为(5,3,2, 1)包含第一个张量中每个元素的索引。我想用第二张量的索引切片第一个张量。我尝试过x = x [idx],但是当我真的希望它的形状为(5,3,2)或(5,3,2,1)时,我得到了一个奇怪的尺寸。
我将尝试举一个简单的例子: 假设
x=torch.Tensor([[10,20,30],
[8,4,43]])
idx = torch.Tensor([[0],
[2]])
我想要类似的东西
y = x[idx]
使得'y'输出[[10],[43]]
或类似的东西。
索引表示所需元素最后一个维度的位置。对于上面的示例,其中x.shape =(2,3)的最后一个维度是列,则“ idx”中的索引是列。我想要这个,但要超过2个维度
答案 0 :(得分:0)
根据我的评论,您需要idx
作为最后一个维度的索引,idx
中的每个索引都对应于x
中的类似索引(最后一个维度除外) )。在这种情况下(这是Numpy版本,您可以将其转换为手电筒):
ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]
输出:
[[10]
[43]]
答案 1 :(得分:0)
您可以使用range
;和squeeze
以获得正确的idx
尺寸,例如
x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])
# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
[43.]])