我有一个大小为(n,m)的pytorch张量A和大小为n的索引列表,这样每个项0 <= index [i] A = torch.tensor([[1,2,3],[4,5,6]])
indices = torch.tensor([1, 2])
#desired result
A = [[1,-2,3],[4,5,-6]]
答案 0 :(得分:0)
当然可以,fancy indexing是必经之路:
import torch
A = torch.Tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.LongTensor([1, 2])
A[range(A.shape[0]), indices] *= -1
记住索引必须为torch.LongTensor
类型。如果您使用float
成员函数拥有.long()
,则可以进行投射。