为每个图像像素有效地将所有但最大的n个元素归零

时间:2018-01-13 16:34:33

标签: python numpy

所以我有一个大小为(H x W x C)的图像,其中C是一些通道。挑战是获得一个新的图像J,同样大小(H x W x C),其中J [i,j]仅包含I [i,j]中的最多n个条目。

等效地,考虑迭代I中的每个图像像素并将除最高n个条目之外的所有图像归零。

我尝试了什么:

# NOTE: bone_weight_matrix is a matrix of size (256 x 256 x 43)
argsort_four = np.argsort(bone_weight_matrix, axis=2)[:, :, -4:]

# For each pixel, retain only the top four influencing bone weights 
    proc_matrix = np.zeros(bone_weight_matrix.shape)

    for i in range(bone_weight_matrix.shape[0]):
        for j in range(bone_weight_matrix.shape[1]):
                proc_matrix[i, j, argsort_four[i, j]] = bone_weight_matrix[i, j, argsort_four[i, j]]

    return proc_matrix

问题是这种方法似乎超级慢,并且感觉不是非常pythonic。任何建议都会很棒。

干杯。

1 个答案:

答案 0 :(得分:1)

通用案例:沿轴保留最大或最小n个元素

基本上会涉及两个步骤:

  • 使用np.argparition将指定轴上的n索引保留。

  • 初始化一个零数组,并使用那些先前获得的advanced-indexing索引从输入数组中选择,并分配到零数组。

让我们尝试解决一个通用问题,该问题可以在指定的轴上选择n个元素,并且能够保持最大的n以及最小的n元件。

实现看起来像这样 -

def keep(ar, n, axis=-1, order='largest'):
    axis = np.core.multiarray.normalize_axis_index(axis, ar.ndim)
    slice_l = [slice(None, None, None)]*ar.ndim

    if order=='largest':
        slice_l[axis] = slice(-n,None,None)
        idx = np.argpartition(ar, kth=-n, axis=axis)[slice_l]
    elif order=='smallest':
        slice_l[axis] = slice(None,n,None)
        idx = np.argpartition(ar, kth=n, axis=axis)[slice_l]
    else:
        raise Exception('Invalid order value')

    grid = np.ogrid[tuple(map(slice, ar.shape))]
    grid[axis] = idx
    out = np.zeros_like(ar)
    out[grid] = ar[grid]
    return out

示例运行

输入数组:

In [208]: np.random.seed(0)
     ...: I = np.random.randint(11,99,(2,2,6))

In [209]: I
Out[209]: 
array([[[55, 58, 75, 78, 78, 20],
        [94, 32, 47, 98, 81, 23]],

       [[69, 76, 50, 98, 57, 92],
        [48, 36, 88, 83, 20, 31]]])

沿最后一个轴保留最大2个元素:

In [210]: keep(I, n=2, axis=-1, order='largest')
Out[210]: 
array([[[ 0,  0,  0, 78, 78,  0],
        [94,  0,  0, 98,  0,  0]],

       [[ 0,  0,  0, 98,  0, 92],
        [ 0,  0, 88, 83,  0,  0]]])

沿第一轴保持最大1元素:

In [211]: keep(I, n=1, axis=1, order='largest')
Out[211]: 
array([[[ 0, 58, 75,  0,  0,  0],
        [94,  0,  0, 98, 81, 23]],

       [[69, 76,  0, 98, 57, 92],
        [ 0,  0, 88,  0,  0,  0]]])

沿最后一轴保留最小2个元素:

In [212]: keep(I, n=2, axis=-1, order='smallest')
Out[212]: 
array([[[55,  0,  0,  0,  0, 20],
        [ 0, 32,  0,  0,  0, 23]],

       [[ 0,  0, 50,  0, 57,  0],
        [ 0,  0,  0,  0, 20, 31]]])