每行中的特定列乘以-1

时间:2019-07-24 19:44:14

标签: pytorch

我有一个大小为(n,m)的pytorch张量A和大小为n的索引列表,这样每个项0 <= index [i]

 A = torch.tensor([[1,2,3],[4,5,6]])
 indices = torch.tensor([1, 2])

 #desired result
 A = [[1,-2,3],[4,5,-6]]

1 个答案:

答案 0 :(得分:0)

当然可以,fancy indexing是必经之路:

import torch

A = torch.Tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.LongTensor([1, 2])

A[range(A.shape[0]), indices] *= -1

记住索引必须为torch.LongTensor类型。如果您使用float成员函数拥有.long(),则可以进行投射。