如何将3D数组转换为argmax的2D数组?

时间:2019-08-25 07:14:44

标签: python numpy numpy-ndarray

我有一个大小为(W,H,C)的Numpy数组,其中“ C”是语义分割任务的许多类。我需要的是一个大小为(H,W)的Numpy数组,其中每个元素都是适合该像素的类的索引。

我找到了一种运行速度非常慢的方法。

masks = {list of 2d binary masks}
output_mask = np.zeros(width * height)
output_mask = output_mask.reshape(width, height)

for i in range(width):
    for j in range(height):
        class_id = 0
        for mask in masks:
            class_id += 1
            if mask[i, j] == 1:
                output_mask[i, j] = class_id

我希望可能会有更好的方法。谁能帮我吗?

1 个答案:

答案 0 :(得分:0)

import numpy as np

arr = np.random.rand(10, 10, 3)
max_val = np.argmax(arr, axis=-1)

print(max_val.shape)