所以我有一个大小为(H x W x C)的图像,其中C是一些通道。挑战是获得一个新的图像J,同样大小(H x W x C),其中J [i,j]仅包含I [i,j]中的最多n个条目。
等效地,考虑迭代I中的每个图像像素并将除最高n个条目之外的所有图像归零。
我尝试了什么:
# NOTE: bone_weight_matrix is a matrix of size (256 x 256 x 43)
argsort_four = np.argsort(bone_weight_matrix, axis=2)[:, :, -4:]
# For each pixel, retain only the top four influencing bone weights
proc_matrix = np.zeros(bone_weight_matrix.shape)
for i in range(bone_weight_matrix.shape[0]):
for j in range(bone_weight_matrix.shape[1]):
proc_matrix[i, j, argsort_four[i, j]] = bone_weight_matrix[i, j, argsort_four[i, j]]
return proc_matrix
问题是这种方法似乎超级慢,并且感觉不是非常pythonic。任何建议都会很棒。
干杯。
答案 0 :(得分:1)
n
个元素基本上会涉及两个步骤:
使用np.argparition
将指定轴上的n
索引保留。
初始化一个零数组,并使用那些先前获得的advanced-indexing
索引从输入数组中选择,并分配到零数组。
让我们尝试解决一个通用问题,该问题可以在指定的轴上选择n
个元素,并且能够保持最大的n
以及最小的n
元件。
实现看起来像这样 -
def keep(ar, n, axis=-1, order='largest'):
axis = np.core.multiarray.normalize_axis_index(axis, ar.ndim)
slice_l = [slice(None, None, None)]*ar.ndim
if order=='largest':
slice_l[axis] = slice(-n,None,None)
idx = np.argpartition(ar, kth=-n, axis=axis)[slice_l]
elif order=='smallest':
slice_l[axis] = slice(None,n,None)
idx = np.argpartition(ar, kth=n, axis=axis)[slice_l]
else:
raise Exception('Invalid order value')
grid = np.ogrid[tuple(map(slice, ar.shape))]
grid[axis] = idx
out = np.zeros_like(ar)
out[grid] = ar[grid]
return out
示例运行
输入数组:
In [208]: np.random.seed(0)
...: I = np.random.randint(11,99,(2,2,6))
In [209]: I
Out[209]:
array([[[55, 58, 75, 78, 78, 20],
[94, 32, 47, 98, 81, 23]],
[[69, 76, 50, 98, 57, 92],
[48, 36, 88, 83, 20, 31]]])
沿最后一个轴保留最大2
个元素:
In [210]: keep(I, n=2, axis=-1, order='largest')
Out[210]:
array([[[ 0, 0, 0, 78, 78, 0],
[94, 0, 0, 98, 0, 0]],
[[ 0, 0, 0, 98, 0, 92],
[ 0, 0, 88, 83, 0, 0]]])
沿第一轴保持最大1
元素:
In [211]: keep(I, n=1, axis=1, order='largest')
Out[211]:
array([[[ 0, 58, 75, 0, 0, 0],
[94, 0, 0, 98, 81, 23]],
[[69, 76, 0, 98, 57, 92],
[ 0, 0, 88, 0, 0, 0]]])
沿最后一轴保留最小2
个元素:
In [212]: keep(I, n=2, axis=-1, order='smallest')
Out[212]:
array([[[55, 0, 0, 0, 0, 20],
[ 0, 32, 0, 0, 0, 23]],
[[ 0, 0, 50, 0, 57, 0],
[ 0, 0, 0, 0, 20, 31]]])