我有一个来自tensor.max()
操作的Numpy 2D数组(4000,8000),它存储4D数组(30,4000,8000,3)的第一维索引。我需要获得一个(4000,8000,3)数组,该数组在这组图像上使用索引,并提取2D max数组中每个位置的像素。
A = np.random.randint( 0, 29, (4000,8000), dtype=int)
B = np.random.randint(0,255,(30,4000,8000,3),dtype=np.uint8)
final = np.zeros((B.shape[1],B.shape[2],3))
r = 0
c = 0
for row in A:
c = 0
for col in row:
x = A[r,c]
final[r,c] = B[x,r,c]
c=c+1
r=r+1
print(final.shape)
有矢量化的方法吗?我正在使用循环来应对RAM使用情况。谢谢
答案 0 :(得分:0)
您可以使用np.take_along_axis
。
首先让我们创建一些数据(您应该提供一个reproducible example):
>>> N, H, W, C = 10, 20, 30, 3
>>> arr = np.random.randn(N, H, W, C)
>>> indices = np.random.randint(0, N, size=(H, W))
然后,我们将使用np.take_along_axis
。但是为此,indices
数组必须与arr
数组具有相同的形状。因此,我们使用np.newaxis
在形状不匹配的地方插入轴。
>>> res = np.take_along_axis(arr, indices[np.newaxis, ..., np.newaxis], axis=0)
它已经提供了可用的输出,但是在第一轴上具有单例尺寸:
>>> res.shape
(1, 20, 30, 3)
所以我们可以挤压它:
>>> res = np.squeeze(res)
>>> res.shape
(20, 30, 3)
并最终检查数据是否如我们所愿:
>>> np.all(res[0, 0] == arr[indices[0, 0], 0, 0])
True
>>> np.all(res[5, 3] == arr[indices[5, 3], 5, 3])
True