如何用2D蒙版遮罩3D张量并保持原始矢量的尺寸?

时间:2020-05-22 14:08:56

标签: pytorch tensor

假设我有一个3D张量A

A = torch.arange(24).view(4, 3, 2)
print(A)

并要求使用2D张量对其进行遮罩

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

使用PyTorch中的masked_select功能会导致以下错误。

torch.masked_select(X, (mask == 1))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
     12 
     13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
     15 #Y = X * mask_
     16 print(Y)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2

如何用2D蒙版遮罩3D张量并保持原始矢量的尺寸?任何提示将不胜感激。

1 个答案:

答案 0 :(得分:0)

本质上,我们需要将张量掩码的尺寸与要掩盖的张量匹配。

有两种方法。

方法1:不保留原始张量尺寸。

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)

方法1的输出:

tensor([ 0,  1,  8,  9, 18, 19])

方法2:保留原始张量尺寸(通过填充)。

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = X * mask_
print(Y)

方法2的输出:

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]]