给定索引列表,将任意维度的 numpy ndarry 切片为一维数组

时间:2021-03-22 23:46:38

标签: python numpy slice numpy-ndarray numpy-slicing

我有一个 numpy ndarray arrindices,一个指定特定条目的索引列表。为了具体起见,让我们采取:

arr = np.arange(2*3*4).reshape((2,3,4))
indices= [1,0,3]

我有代码可以通过 arr 观察除一个索引之外的所有索引 n

arr[:, indices[1], indices[2]]  # n = 0
arr[indices[0], :, indices[2]]  # n = 1
arr[indices[0], indices[1], :]  # n = 2

我想更改我的代码以遍历 n 并支持任意维度的 arr

我查看了文档中的 indexing routines 条目并找到了有关 slice()np.s_() 的信息。我能够将一些像我想要的那样工作的东西组合在一起:

def make_custom_slice(n, indices):
    s = list()
    for i, idx in enumerate(indices):
        if i == n:
            s.append(slice(None))
        else:
            s.append(slice(idx, idx+1))
    return tuple(s)


for n in range(arr.ndim):
    np.squeeze(arr[make_custom_slice(n, indices)])

其中 np.squeeze 用于删除长度为 1 的轴。没有它,生成的数组的形状为 (arr.shape[n],1,1,...) 而不是 (arr.shape[n],)

有没有更惯用的方法来完成这项任务?

1 个答案:

答案 0 :(得分:1)

对上述解决方案的一些改进(可能仍然存在单行或更高效的解决方案):

def make_custom_slice(n, indices):
    s = indices.copy()
    s[n] = slice(None)
    return tuple(s)


for n in range(arr.ndim):
    print(arr[make_custom_slice(n, indices)])

一个整数值 idx 可用于替换切片对象 slice(idx, idx+1)。因为大多数索引都是直接复制的,所以从索引的副本开始,而不是从头开始构建列表。

以这种方式构建时,arr[make_custom_slice(n, indices) 的结果具有预期的维度,而 np.squeeze 是不必要的。

相关问题