如何在索引处添加到pytorch张量?

时间:2019-09-19 13:49:49

标签: pytorch

我必须承认,我对分散*和索引*操作有些困惑-我不确定它们中的任何一个都可以完全满足我的要求,这很简单:

给出一些二维张量

z = tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])

以及二维索引列表(或张量?):

inds = tensor([[0, 0],
               [1, 1],
               [1, 2]])

我想在这些索引处向z添加标量(并有效地做到这一点):

znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
               [1., 4., 4., 1.],
               [1., 1., 1., 1.]])

如果必须的话,我可以使该标量成为任何形状的张量(其中所有元素= 3),但是我宁愿不要...

2 个答案:

答案 0 :(得分:1)

您必须为索引提供两个列表。第一个具有行位置,第二个具有列位置。在您的示例中,它将是:

z[[0, 1, 1], [0, 1, 2]] += 3

torch.Tumsor索引在Numpy之后。有关更多详细信息,请参见https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing

答案 1 :(得分:1)

此代码实现了您想要的:

z_new = z.clone() # copy the tensor
z_new[inds[:, 0], inds[:, 1]] += 3 # modify selected indices of new tensor

在PyTorch中,您可以将张量的每个轴与另一个张量索引。