pytorch中2d张量的高级索引

时间:2017-12-28 11:36:23

标签: multidimensional-array indexing pytorch tensor numpy-broadcasting

我有一个2d张量X.和两个索引列表是第一个索引,第二个是a和b。我想做

X[a[i],b[i]] = 0 for i in range(len(a))

我该怎么做?如果我直接执行X[a,b]错误是IndexError:无法广播高级索引对象

1 个答案:

答案 0 :(得分:2)

检查包含索引的lists,某些值可能超出范围。当您获得 IndexError 时,如下所示:

  

在[43]中:X [4,4]

     

IndexError Traceback(最近一次调用最后一次)    in()   ----> 1 X [4,4]

     

IndexError:索引4超出了维度0(大小为3)的范围

如果您的指数范围正确,它应该可以正常工作。

以下是一个例子:

In [35]: X = torch.Tensor([[3, 4, 5, 6], [1, 2, 3, 4], [6, 3, 2, 1]])

In [36]: X
Out[36]: 

 3  4  5  6
 1  2  3  4
 6  3  2  1
[torch.FloatTensor of size 3x4]

In [37]: a = [0, 2]

In [38]: b = [1, 2]

In [39]: X[a, b]
Out[39]: 

 4
 2
[torch.FloatTensor of size 2]

In [40]: X[a, b] = 0

In [41]: X
Out[41]: 

 3  0  5  6
 1  2  3  4
 6  3  0  1
[torch.FloatTensor of size 3x4]