向量化python代码以修改图像分割中使用的蒙版

时间:2019-05-16 11:16:31

标签: python numpy

我正在处理图像分割挑战。我有具有5个标签(0、1、2、3、4)的蒙版,其中一个这样的蒙版(2D矩阵)的布局为:

[0 0 0 0 0 0
 0 0 1 1 2 2
 1 1 1 1 2 2
 2 2 2 2 2 2
 3 3 3 3 3 3
 4 4 4 4 4 4]

我想合并几个类,使修改后的蒙版看起来像:

[0 0 0 0 0 0
 0 0 0 0 1 1
 0 0 0 0 1 1
 1 1 1 1 1 1
 2 2 2 2 2 2
 2 2 2 2 2 2]

合并0和1到0。 从2更改为1。 将3和4合并为2。

我实现了一个循环版本,因为我的蒙版尺寸为(601,462,951),所以要花很多时间。

for i in range(0, dim.shape[0]):
  for j in range(0, dim.shape[1]):
    for k in range(0, dim.shape[2]):
      if dim[i, j, k] in (0, 2):
        dim[i, j, k] = 1

      if dim[i, j, k] == 3:
        dim[i, j, k] = 2

      if dim[i, j, k] in (4, 8):
        dim[i, j, k] = 3

      if dim[i, j, k] == 9:
        dim[i, j, k] = 4

我找不到任何方法来对代码进行矢量化处理,以便删除循环。

3 个答案:

答案 0 :(得分:1)

您可以使用np.select来获得简洁的解决方案,该解决方案使您可以从choicelist中选择条件列表:

np.select([a==1, a==2, (a==3)|(a==4)], [0,1,2])

array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 1, 1],
       [1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2],
       [2, 2, 2, 2, 2, 2]])

答案 1 :(得分:0)

您可以这样做:

@section('page_tagline', __('pages.home.tagline'))

其中a[a == 1] = 0 a[a == 2] = 1 a[(a == 3) | (a == 4)] = 2 是您的numpy数组。

答案 2 :(得分:0)

使用mapper

您可以创建一个映射数组,然后简单地用输入数组建立索引就可以为我们提供所需的输出-

mapper = np.array([0,0,1,2,2])
out = mapper[a] # a is input array

具有相同给定形状的所有已发布解决方案的时间-(601, 462, 951))-

In [60]: np.random.seed(0)
    ...: a = np.random.randint(0,5,(601, 462, 951))

# @yatu's soln
In [61]: %timeit np.select([a==1, a==2, (a==3)|(a==4)], [0,1,2])
1 loop, best of 3: 5 s per loop

# Posted in this post
In [62]: %%timeit
    ...: mapper = np.array([0,0,1,2,2])
    ...: out = mapper[a]
1 loop, best of 3: 849 ms per loop

# @Austin's soln
In [63]: %%timeit
    ...: a[a == 1] = 0
    ...: a[a == 2] = 1
    ...: a[(a == 3) | (a == 4)] = 2
1 loop, best of 3: 1.04 s per loop

具有较低精度dtype的进一步增强功能

由于输出将带有标签-0,1,2,因此我们可以安全地将UINT8用作输出数据类型,并获得巨大的性能提升。因此,它将是-

mapper = np.array([0,0,1,2,2],dtype=np.uint8)
out = mapper[a]

时间-

In [66]: np.random.seed(0)
    ...: a = np.random.randint(0,5,(601, 462, 951))

In [67]: %%timeit
    ...: mapper = np.array([0,0,1,2,2],dtype=np.uint8)
    ...: out = mapper[a]
1 loop, best of 3: 380 ms per loop