如何遍历numpy数组的第n个维度?

时间:2019-06-05 12:17:31

标签: python numpy

我通常将任意形状的numpy数组连接起来,以使我的代码更整洁,但是,我似乎很难以pythonesque的方式对其进行迭代。

让我们考虑一个4维数组x(因此len(x.shape) = 4),并且我要迭代的索引是2,我通常使用的天真的解决方案是这样的

y = np.array([my_operation(x[:, :, i, :])
              for i in range(x.shape[2])])

我正在寻找更具可读性的东西,因为拥有这么多“:”很烦人,并且x尺寸的任何更改都需要重写我的代码的一部分。诸如此类的魔术

y = np.array([my_operation(z) for z in magic_function(x, 2)])

是否有一个numpy方法可以让我遍历数组的任意维?

3 个答案:

答案 0 :(得分:0)

一种可能的解决方案是使用dict()。

您可以做的是:

x = dict()
x['param1'] = [1, 1, 1, 1]
x['param2'] = [2, 2, 2, 2]

print(x['param1']) 
# > [1, 1, 1, 1]

答案 1 :(得分:0)

我不知道这样做的任何标准方法。无论如何,你的把戏很好。我们可以对其进行详细说明,然后为您正在寻找的“魔术功能”提供一个实现:

def magic_function(x, n):
    slices = [slice(w) for w in x.shape]
    for i in range(x.shape[n]):
        slices[n] = i
        z = x[tuple(slices)]
        yield z

答案 2 :(得分:0)

您可以暂时将所需的轴移到前面,然后在数组上进行迭代。然后将轴移回:

x = np.moveaxis(x, 2, 0)
x = np.array([my_operation(sub_x) for sub_x in x])
x = np.moveaxis(x, 0, 2)