在较低维数组中拆分数组的最后一个维度

时间:2015-03-11 14:49:40

标签: numpy split

假设我们有一个NxMxD形状的数组。我想获得一个包含D NxM数组的列表。

正确的做法是:

np.dsplit(myarray, D)

但是,这会返回D NxMx1数组。

我可以通过以下方式达到预期的效果:

[myarray[..., i] for i in range(D)]

或者:

[np.squeeze(subarray) for subarray in np.dsplit(myarray, D)]

但是,我觉得需要执行额外的操作有点多余。我错过了任何返回所需结果的numpy函数吗?

2 个答案:

答案 0 :(得分:4)

尝试D.swapaxes(1,2).swapaxes(1,0)

>>>import numpy as np
>>>a = np.arange(24).reshape(2,3,4)
>>>a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

>>>[a[:,:,i] for i in range(4)]
[array([[ 0,  4,  8],
       [12, 16, 20]]),
 array([[ 1,  5,  9],
       [13, 17, 21]]),
 array([[ 2,  6, 10],
       [14, 18, 22]]),
 array([[ 3,  7, 11],
       [15, 19, 23]])]

>>>a.swapaxes(1,2).swapaxes(1,0)
array([[[ 0,  4,  8],
        [12, 16, 20]],

       [[ 1,  5,  9],
        [13, 17, 21]],

       [[ 2,  6, 10],
        [14, 18, 22]],

       [[ 3,  7, 11],
        [15, 19, 23]]])

编辑:正如ajcr所指出的那样(再次感谢),transpose命令更方便,因为可以使用

一步完成两个交换
D.transpose(2,0,1)

答案 1 :(得分:2)

np.dsplit使用np.array_split,其核心是:

sub_arys = []
sary = _nx.swapaxes(ary, axis, 0)
for i in range(Nsections):
    st = div_points[i]; end = div_points[i+1]
    sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))

axis=-1,这相当于:

[x[...,i:(i+1)] for i in np.arange(x.shape[-1])]  # or
[x[...,[i]] for i in np.arange(x.shape[-1])]

考虑了单身维度。

所以你的

没有错误或效率低下
[x[...,i] for i in np.arange(x.shape[-1])]

实际上在快速时间测试中,dsplit的任何使用都很慢。这是一般性成本。因此添加squeeze相对便宜。

但是通过接受另一个答案,看起来你真的在寻找一个正确形状的数组,而不是一个数组列表。对于许多有意义的操作。当子阵列沿分割轴有多个“行”,或者甚至是“行”数不均匀时,split更有用。