使用pytorch

时间:2019-01-03 19:44:02

标签: pytorch

我有一个3维的numpy数组,例如:

 x = [[[0.3, 0.2, 0.5],
       [0.1, 0.2, 0.7],
       [0.2, 0.2, 0.6]]]

indexs数组也是3维的,例如:

indices = [[[0],
            [1],
            [2]]]

我希望输出是:

 output= [[[0.3],
           [0.2],
           [0.6]]]

我尝试了torch.index_select和torch.gather函数,但是找不到正确的方法来处理尺寸。感谢您的帮助!

2 个答案:

答案 0 :(得分:1)

如何使用x.gather(dim=2, indices)?这对我行得通。

答案 1 :(得分:0)

我找到了答案。请让我知道是否有更好的解决方案。

torch.cat([torch.index_select(a.view(1, -1), 1, i.view(1, -1)[0]) 
                                         for a, i in zip(x, indices)])