Pytorch:使用索引列表访问子张量

时间:2020-06-04 16:55:24

标签: python pytorch tensor matrix-indexing tensor-indexing

我有一对张量ST,尺寸为(s1,...,sm)(t1,...,tn),且si < ti。我想在T的每个维度中指定索引列表,以“嵌入” S中的T。如果I1s1(0,1,...,t1)索引的列表,并且同样适用于I2In的索引,我想做类似的事情 T.select(I1,...,In)=S 这将导致现在T在索引S上具有等于(I1,...,In)条目的条目。 例如

`S=
[[1,1],
[1,1]]

T=
[[0,0,0],
[0,0,0],
[0,0,0]]

T.select([0,2],[0,2])=S

T=
[[1,0,1],
[0,0,0],
[1,0,1]]`

1 个答案:

答案 0 :(得分:1)

如果您灵活地仅对索引部分使用NumPy ,那么这是一种方法,方法是使用numpy.ix_()构造一个开放式网格并使用该网格来填充来自张量S。如果不可接受,则可以使用torch.meshgrid()

下面是两种方法的说明,说明中散布着注释。

# input tensors to work with
In [174]: T 
Out[174]: 
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])

# I'm using unique tensor just for clarity; But any tensor should work.
In [175]: S 
Out[175]: 
tensor([[10, 11],
        [12, 13]])

# indices where we want the values from `S` to be filled in, along both dimensions
In [176]: idxs = [[0,2], [0,2]] 

现在,我们将通过传递索引来利用np.ix_()torch.meshgrid()来生成开放网格:

# mesh using `np.ix_`
In [177]: mesh = np.ix_(*idxs)

# as an alternative, we can use `torch.meshgrid()` 
In [191]: mesh = torch.meshgrid([torch.tensor(lst) for lst in idxs])

# replace the values from tensor `S` using basic indexing
In [178]: T[mesh] = S 

# sanity check!
In [179]: T 
Out[179]:
tensor([[10,  0, 11],
        [ 0,  0,  0],
        [12,  0, 13]])