如何在某些行和列中更改2D张量的值

时间:2020-04-09 09:08:57

标签: python pytorch

假设我有一个全零的掩码张量,像这样:

mask = torch.zeros(5,3, dtype=torch.bool)

现在,我想在以下maskrows索引的交点处将cols的值设置为True

rows = torch.tensor([0,2,4]) 
cols = torch.tensor([1,2])

我想产生以下结果:

tensor([[False, True,  True ],
        [False, False, False],
        [False, True,  True ],
        [False, False, False],
        [False, True,  True ]])

当我尝试以下代码时,出现错误:

mask[rows, cols] = True

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]

如何在PyTorch中高效地做到这一点?

1 个答案:

答案 0 :(得分:1)

您需要合适的形状才能使用torch.unsqueeze

mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])

torch.reshape

mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])