我有一个大小为(W,H,C)的Numpy数组,其中“ C”是语义分割任务的许多类。我需要的是一个大小为(H,W)的Numpy数组,其中每个元素都是适合该像素的类的索引。
我找到了一种运行速度非常慢的方法。
masks = {list of 2d binary masks}
output_mask = np.zeros(width * height)
output_mask = output_mask.reshape(width, height)
for i in range(width):
for j in range(height):
class_id = 0
for mask in masks:
class_id += 1
if mask[i, j] == 1:
output_mask[i, j] = class_id
我希望可能会有更好的方法。谁能帮我吗?
答案 0 :(得分:0)
import numpy as np
arr = np.random.rand(10, 10, 3)
max_val = np.argmax(arr, axis=-1)
print(max_val.shape)