我有以下代码:
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
我有一个多维索引b
,想用它来选择a
中的单个单元格。如果b不是张量,我可以做:
a[1,1,1,1]
哪个返回正确的单元格,但是:
a[b]
不起作用,因为它只选择了a[1]
四次。
我该怎么做?谢谢
答案 0 :(得分:4)
您可以使用chunk
将b
分成4个,然后使用成块的b
来索引所需的特定元素:
>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)] # here's the trick!
Out[24]: tensor([[40, 80, 0]])
关于它的好处是可以轻松地将其推广到a
的任何尺寸,您只需要使卡盘的数量等于a
的尺寸即可。
答案 1 :(得分:4)
更优雅(更简单)的解决方案可能是将b
强制转换为元组:
a[tuple(b)]
Out[10]: tensor(5.)
我很想知道它如何与“常规” numpy一起工作,并找到了相关的文章对此进行了很好的解释here。