我有一个3维的numpy数组,例如:
x = [[[0.3, 0.2, 0.5],
[0.1, 0.2, 0.7],
[0.2, 0.2, 0.6]]]
indexs数组也是3维的,例如:
indices = [[[0],
[1],
[2]]]
我希望输出是:
output= [[[0.3],
[0.2],
[0.6]]]
我尝试了torch.index_select和torch.gather函数,但是找不到正确的方法来处理尺寸。感谢您的帮助!
答案 0 :(得分:1)
如何使用x.gather(dim=2, indices)
?这对我行得通。
答案 1 :(得分:0)
我找到了答案。请让我知道是否有更好的解决方案。
torch.cat([torch.index_select(a.view(1, -1), 1, i.view(1, -1)[0])
for a, i in zip(x, indices)])