使用2d数组作为4d数组的索引

时间:2020-07-09 14:53:02

标签: arrays numpy pytorch indices

我有一个来自tensor.max()操作的Numpy 2D数组(4000,8000),它存储4D数组(30,4000,8000,3)的第一维索引。我需要获得一个(4000,8000,3)数组,该数组在这组图像上使用索引,并提取2D max数组中每个位置的像素。

A = np.random.randint( 0, 29, (4000,8000), dtype=int)
B = np.random.randint(0,255,(30,4000,8000,3),dtype=np.uint8)

final = np.zeros((B.shape[1],B.shape[2],3))

r = 0
c = 0
for row in A:
  c = 0
  for col in row:
    x = A[r,c]
    final[r,c] = B[x,r,c]
    c=c+1
  r=r+1

print(final.shape)

有矢量化的方法吗?我正在使用循环来应对RAM使用情况。谢谢

1 个答案:

答案 0 :(得分:0)

您可以使用np.take_along_axis

首先让我们创建一些数据(您应该提供一个reproducible example):

>>> N, H, W, C = 10, 20, 30, 3
>>> arr = np.random.randn(N, H, W, C)
>>> indices = np.random.randint(0, N, size=(H, W))

然后,我们将使用np.take_along_axis。但是为此,indices数组必须与arr数组具有相同的形状。因此,我们使用np.newaxis在形状不匹配的地方插入轴。

>>> res = np.take_along_axis(arr, indices[np.newaxis, ..., np.newaxis], axis=0)

它已经提供了可用的输出,但是在第一轴上具有单例尺寸:

>>> res.shape
(1, 20, 30, 3)

所以我们可以挤压它:

>>> res = np.squeeze(res)

>>> res.shape
(20, 30, 3)

并最终检查数据是否如我们所愿:

>>> np.all(res[0, 0] == arr[indices[0, 0], 0, 0])
True

>>> np.all(res[5, 3] == arr[indices[5, 3], 5, 3])
True