如何从单行的numpy数组中拆分多个蒙版?

时间:2019-12-25 14:59:57

标签: python numpy

我有一个500,500,3形的numpy数组,其中包含重复的整数值。我可以使用np.unique来获取所述数组中存在的唯一整数值​​(因为对于每个这样的数组它们都可以不同)。

有没有办法在一行中分割多个蒙版。


  

蒙版是一个小数组。它包含不同的重复整数值

 ids = np.unique(masks) # ids = [0, 1, 2] for example
 # currently doing this
 mask0 = masks == ids[0]
 mask1 = masks == ids[1]
 mask2 = masks == ids[2]

是否有任何单行方法来获取一组所有二进制掩码。例如类似的东西。

 all_masks = masks == ids[:] # for example

1 个答案:

答案 0 :(得分:1)

您可以将两个数组广播成(id, height, width, channels)的形状,然后检查是否相等:

all_masks = (masks[np.newaxis] == ids[:, np.newaxis, np.newaxis, np.newaxis])

结果是,all_masks[i]是第(height, width, channels)个ID的i二进制掩码。