生成三个覆盖所有二维矩阵的非重叠蒙版

时间:2018-09-16 17:21:09

标签: python arrays numpy

我有一个2-d数组,我想通过掩码生成将其分为3个非重叠和随机子矩阵。例如,我有一个如下的矩阵:

input = [[1,2,3],
         [4,5,6],
         [7,8,9]]

我想要三个随机的零一掩码,如下所示:

mask1 = [[0,1,0],
        [1,0,1],
        [0,0,0]]
mask2 = [[1,0,0],
         [0,1,0],
         [1,0,0]]
mask3 =[[0,0,1],
        [0,0,0],
        [0,1,1]]

但是我的输入矩阵太大,我需要以快速的方式进行处理。我还想确定每个蒙版作为输入的比例。在以上示例中,所有蒙版的比率均相等。 为了产生一个随机的蒙版,我使用以下代码:

np.random.choice([0, 1],size=(size of matrix[0],size of matrix[1]))

我的问题是如何制作不重叠的蒙版。

1 个答案:

答案 0 :(得分:2)

IIUC,您可以创建0、1和2的随机矩阵,然后提取m == 0,m == 1和m == 2的值:

groups = np.random.randint(0, 3, (5,5))
masks = (groups[...,None] == np.arange(3)[None,:]).T

但是,这不能保证每个蒙版中元素的数量相等。为此,您可以平衡分配:

a = np.arange(25).reshape(5,5)  # dummy input
groups = np.random.permutation(np.arange(a.size) % 3).reshape(a.shape)
masks = (groups[...,None] == np.arange(3)[None,:]).T

如果您希望将随机概率归为一组:

groups = np.random.choice([0,1,2], p=[0.3, 0.6, 0.1], size=a.shape)

之类的。您需要做的就是确定如何为groups分配单元格,然后可以构建掩码。

例如:

In [431]: groups = np.random.permutation(np.arange(a.size) % 3).reshape(a.shape)

In [432]: groups
Out[432]: 
array([[1, 0, 0, 2, 0],
       [1, 2, 0, 0, 1],
       [2, 0, 2, 0, 2],
       [1, 1, 2, 1, 0],
       [2, 2, 1, 1, 0]], dtype=int32)

In [433]: masks = (groups[...,None] == np.arange(3)[None,:]).T

In [434]: masks
Out[434]: 
array([[[False, False, False, False, False],
        [ True, False,  True, False, False],
        [ True,  True, False, False, False],
        [False,  True,  True, False, False],
        [ True, False, False,  True,  True]],

       [[ True,  True, False,  True, False],
        [False, False, False,  True, False],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False,  True, False, False, False]],

       [[False, False,  True, False,  True],
        [False,  True, False, False,  True],
        [False, False,  True,  True, False],
        [ True, False, False, False, False],
        [False, False,  True, False, False]]])

这给了我一个完整的面具:

In [450]: masks.sum(axis=0)
Out[450]: 
array([[1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1]])

并保持合理平衡。如果单元格的数目是3的倍数,则这些数目都将一致。

In [451]: masks.sum(2).sum(1)
Out[451]: array([9, 8, 8])

如果愿意,可以使用.astype(int)从布尔数组转换为0和1的整数数组。