有没有一种在Pytorch中创建随机位掩码的有效方法?

时间:2018-03-11 04:24:03

标签: python pytorch

我希望有一个随机位掩码,它具有0 s的指定百分比。我设计的功能是:

def create_mask(shape, rate):
    """
    The idea is, you take a random permutations of numbers. You then mod then
    mod it by the [number of entries in the bitmask] / [percent of 0s you
    want]. The number of zeros will be exactly the rate of zeros need. You
    can clamp the values for a bitmask.
    """
    mask = torch.randperm(reduce(operator.mul, shape, 1)).float().cuda()
    # Mod it by the percent to get an even dist of 0s.
    mask = torch.fmod(mask, reduce(operator.mul, shape, 1) / rate)
    # Anything not zero should be put to 1
    mask = torch.clamp(mask, 0, 1)
    return mask.view(shape)

举例说明:

>>> x = create_mask((10, 10), 10)
>>> x

    1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     0     1     1     1
    0     1     1     1     1     0     1     1     1     1
    0     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     0
    1     1     1     1     1     1     1     1     1     1
    1     1     1     0     1     1     1     0     1     1
    0     1     1     1     1     1     1     1     1     1
    1     1     1     0     1     1     0     1     1     1
    1     1     1     1     1     1     1     1     1     1
[torch.cuda.FloatTensor of size 10x10 (GPU 0)]

我对此方法的主要问题是需要rate来划分shape。我想要一个接受任意小数的函数,并在位掩码中给出大约rate%的0。此外,我试图找到一种相对有效的方法。因此,我宁愿不将numpy数组从CPU移动到GPU。是否有一种有效的方法可以允许小数rate

3 个答案:

答案 0 :(得分:9)

对于遇到这种情况的任何人,这将直接在GPU上创建一个具有大约80%零的位掩码。 (PyTorch 0.3)

torch.cuda.FloatTensor(10, 10).uniform_() > 0.8

答案 1 :(得分:0)

使用NumPy和cudamat:

import numpy as np
import cudamat as cm

gpuMask = cm.CUDAMatrix(np.random.choice([0, 1], size=(10,10), p=[1./2, 1./2]))

其中列表的元素是1和0概率的分数表示。

答案 2 :(得分:0)

使用 Pytorch 直接在 GPU 上创建位掩码的正确方法是:

import torch

tensor = torch.randn((3, 5), device=torch.device("cuda")) < 0.9

# tensor([[ True,  True, False,  True,  True,  True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True,  True,  True, False, False,  True],
#         [ True, False, False,  True,  True,  True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
#         [ True,  True, False,  True,  True,  True,  True, False,  True,  True],
#         [ True,  True, False, False,  True,  True,  True, False,  True,  True],
#         [ True,  True,  True, False,  True,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True,  True, False, False,  True,  True],
#         [ True, False,  True,  True,  True,  True,  True,  True,  True,  True],
#         [ True,  True,  True,  True,  True,  True,  True,  True, False,  True]],
#        device='cuda:0')