假设我有一个全零的掩码张量,像这样:
mask = torch.zeros(5,3, dtype=torch.bool)
现在,我想在以下mask
和rows
索引的交点处将cols
的值设置为True
:
rows = torch.tensor([0,2,4])
cols = torch.tensor([1,2])
我想产生以下结果:
tensor([[False, True, True ],
[False, False, False],
[False, True, True ],
[False, False, False],
[False, True, True ]])
当我尝试以下代码时,出现错误:
mask[rows, cols] = True
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
如何在PyTorch中高效地做到这一点?
答案 0 :(得分:1)
您需要合适的形状才能使用torch.unsqueeze
mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False, True, True],
[False, False, False],
[False, True, True],
[False, False, False],
[False, True, True]])
mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False, True, True],
[False, False, False],
[False, True, True],
[False, False, False],
[False, True, True]])