给定具有相同形状的数组和蒙版,我想要具有相同形状的蒙版输出,并在蒙版为False时包含0。
例如,
# input array
img = torch.randn(2, 2)
print(img)
# tensor([[0.4684, 0.8316],
# [0.8635, 0.4228]])
print(img.shape)
# torch.Size([2, 2])
# mask
mask = torch.BoolTensor(2, 2)
print(mask)
# tensor([[False, True],
# [ True, True]])
print(mask.shape)
# torch.Size([2, 2])
# expected masked output of shape 2x2
# tensor([[0, 0.8316],
# [0.8635, 0.4228]])
问题:遮罩会更改输出的形状,如下所示:
#1: shape changed
img[mask]
# tensor([0.8316, 0.8635, 0.4228])
答案 0 :(得分:1)
最直接的方法是创建另一个张量来处理它。
import torch
def generate_masked_tensor(input, mask, fill=0):
masked_tensor = torch.zeros(input.size()) + fill
masked_tensor[mask] = input[mask]
return masked_tensor
if __name__ == "__main__":
img = torch.randn(2, 2)
mask = torch.tensor([False, True, True, False]).bool().view(2, 2)
masked_img = generate_masked_tensor(img, mask)
print (masked_img)
输出:
tensor([[0.0000, 0.8028],
[1.5411, 0.0000]])
答案 1 :(得分:1)
只需将布尔型掩码类型转换为整数掩码,然后使用float将掩码转换为与img
中相同的类型。然后执行逐元素乘法。
masked_output = img * mask.int().float()
答案 2 :(得分:1)
我找到解决方法之一:
img[mask==False] = 0
或使用
img[~mask] = 0
它将更改img
本身。