部分索引多维数组

时间:2018-11-07 17:58:21

标签: python numpy

import numpy as np
arr = np.random.rand(50,3,3,3,16)
ids = (0,0,2,10)
b = arr[:, ids]  # don't work
b = arr[:, *ids]  # don't work
b = arr[:][ids]  # don't work
b = arr[:, tuple(ids)]  # don't work
b = arr[: + ids]  # don't work, obviously..
# b = arr[:,0,0,2,10].shape  # works (desired outcome)

我知道有关此问题,例如Tuple as index of multidimensional arrayUnpacking tuples/arrays/lists as indices for Numpy Arrays,但这些问题都不适合我的情况。基本上,我想在第一个轴上索引所有其余轴中指定的“列”中的所有内容(请参阅代码的最后一行)。 在这种情况下,所需的输出形状应为(50,)

但是我想用一个元组/ ID列表进行索引,因为我需要遍历它们,例如:

all_ids = ((0,0,0,2), (0,0,0,6), (1,1,0,2), (1,1,0,6),
           (2,2,0,2), (2,2,0,6), (2,2,2,2), (2,2,2,6))
c = 0
for id in all_ids:
    c += arr[:, id].sum() 

1 个答案:

答案 0 :(得分:2)

slice(None)添加到ids中的第一个维度,然后添加子集:

arr[(slice(None),) + ids].shape
# (50,)

其中:

(slice(None),) + ids
# (slice(None, None, None), 0, 0, 2, 10)

通知slice(None, None, None)等效于:,即全部切片。您可以阅读docs on using slice object for indexing here