Suppose I have a 3D tensor X:
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
and I want to mask it with a 2D tensor:
mask = torch.zeros((4, 3), dtype=torch.int64)  # dtype=torch.bool also works
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
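A boolean mask can be handed to the masking functions directly, so the `== 1` comparison becomes unnecessary. Here is a minimal sketch that builds the same mask with `dtype=torch.bool` and sets all three positions in one advanced-indexing assignment. The names `rows` and `cols` are illustrative only; they list the (row, column) pairs used above.

```python
import torch

# (row, col) positions to keep: (0, 0), (1, 1) and (3, 0), as in the question
rows = torch.tensor([0, 1, 3])
cols = torch.tensor([0, 1, 0])

mask = torch.zeros((4, 3), dtype=torch.bool)
mask[rows, cols] = True  # advanced indexing sets all three positions at once
print(mask)
```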
Using torch.masked_select raises the following error:
torch.masked_select(X, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
How can I mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor? Any hints would be appreciated.
Answer (score: 0)
Essentially, the mask has to be brought to the same shape as the tensor being masked.
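The error arises because PyTorch aligns shapes starting from the trailing dimension. For `(4, 3, 2)` against `(4, 3)`, that means comparing 2 with 3, which fails. Appending a size-1 dimension turns the mask into `(4, 3, 1)`, and that shape broadcasts against `(4, 3, 2)`. The sketch below checks this with `torch.broadcast_shapes`, which to my knowledge is available from PyTorch 1.8 onward.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Shapes are aligned from the right:
#   X    -> (4, 3, 2)
#   mask ->    (4, 3)   trailing sizes 2 and 3 cannot be broadcast
try:
    torch.masked_select(X, mask == 1)
    broadcast_ok = True
except RuntimeError as err:
    broadcast_ok = False
    print("2D mask fails:", err)

# unsqueeze(-1) appends a size-1 dimension: (4, 3) -> (4, 3, 1)
mask_3d = (mask == 1).unsqueeze(-1)
print(mask_3d.shape)
print(torch.broadcast_shapes(X.shape, mask_3d.shape))
```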
There are two ways to do this.
Method 1: does not preserve the original tensor dimensions.
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # dtype=torch.bool also works
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)
Output of Method 1 (final print(Y) only):
tensor([ 0, 1, 8, 9, 18, 19])
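`masked_select` always returns a flat 1D tensor. Each selected position contributes both of its last-dimension values, so the flat result can be regrouped into one row per selected vector. Indexing `X` with the 2D boolean mask produces those rows in a single step. `masked_select` also broadcasts its mask, which makes the explicit `expand` optional. The sketch below assumes a boolean mask.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

# masked_select broadcasts the (4, 3, 1) mask over X, so expand() is not required
flat = torch.masked_select(X, mask.unsqueeze(-1))
print(flat.tolist())  # [0, 1, 8, 9, 18, 19]

# Regroup the flat result into one row per selected vector
vectors = flat.view(-1, X.size(-1))

# Boolean indexing with the 2D mask selects whole vectors along the last dimension
picked = X[mask]
print(picked.tolist())  # [[0, 1], [8, 9], [18, 19]]
```

`X[mask]` has shape `(3, 2)`. It keeps each selected vector intact, but the positions it came from are no longer recorded in the result.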
Method 2: preserves the original tensor dimensions (unselected entries are filled with zeros).
import torch
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # dtype=torch.bool also works
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Zero out the unselected entries by multiplying with the expanded mask
Y = X * mask_
print(Y)
Output of Method 2 (X, mask, mask_ and Y, in print order):
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]])
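Multiplication broadcasts as well, so Method 2 works without `expand`. When masked-out entries should hold some value other than 0, `masked_fill` or `torch.where` can replace the multiplication. The sketch below uses a fill value of `-1`, which is an arbitrary choice for illustration.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

keep = mask.unsqueeze(-1)  # (4, 3, 1), broadcasts over the last dimension of X

# Same values as Method 2: unselected vectors become 0, shape stays (4, 3, 2)
Y = X * keep

# Fill unselected vectors with -1 instead of 0 (fill value chosen for illustration)
Y_fill = X.masked_fill(~keep, -1)

# Equivalent with torch.where: take X where keep is True, otherwise -1
Y_where = torch.where(keep, X, torch.full_like(X, -1))
print(Y.shape, Y_fill.shape, Y_where.shape)
```

Multiplying by the mask only covers the case where the fill value is 0. `masked_fill` and `torch.where` let you pick any fill value, and both return a tensor with the same shape as `X`.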