优化在单热像素标签中创建空类

时间:2018-01-17 15:36:50

标签: python numpy optimization computer-vision numpy-broadcasting

我正在为图像分割模型准备数据。我有每个像素5个类,不累积覆盖整个图像所以我想创建一个'null'类作为第6类。现在我有一个单热编码的ndarray和一个解决方案,它使我想要优化的一堆Python调用。 我的草图代码现在:

arrs.shape
(25, 25, 5)

null_class = np.zeros(arrs.shape[:-1])
for i in range(arrs.shape[0]):
    for j in range(arrs.shape[1]):
        if not np.any(arrs[i][j] == 1):
            null_class[i][j] = 1

理想情况下,我找到了一种计算空示例的几行和更高性能的方法 - 我的实际训练数据来自20K x 20K图像,我想一次计算和存储所有数据。有什么建议?

1 个答案:

答案 0 :(得分:0)

我相信你可以通过numpy.wherenumpy.all的组合来实现这一目标。使用all检查最后一个维度上的所有零将为您提供一个True的布尔数组,其中null_class应为1。为了显示,我将使用(2,2,5)数组。

arr = np.random.randint(0, 2, size=(2,2,5))
null_class = np.zeros(arr.shape[:-1])
arr[0, 0] = [0, 0, 0, 0, 0]
arr
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1]],

       [[0, 0, 1, 0, 0],
        [0, 1, 1, 1, 0]]])
np.all(arr[:, :] == 0, axis=2)
array([[ True, False],
       [False, False]], dtype=bool)
np.where(np.all(arr[:, :] == 0, axis=2))
(array([0]), array([0]))
null_class[np.where(np.all(arr[:, :] == 0, axis=2)] = 1
null_class
array([[ 1.,  0.],
       [ 0.,  0.]])