迭代通过多维数组的所有1维子阵列

时间:2011-04-16 15:01:20

标签: python multidimensional-array numpy

在python中迭代遍历n维数组的所有一维子数组的最快方法是什么。

例如,考虑三维数组:

import numpy as np 
a = np.arange(24)
a = a.reshape(2,3,4)

迭代器所需的产量序列是:

a[:,0,0]
a[:,0,1]
..
a[:,2,3]
a[0,:,0]
..
a[1,:,3]
a[0,0,:]
..
a[1,2,:]

3 个答案:

答案 0 :(得分:10)

这是一个这样的迭代器的紧凑实现:

def iter1d(a):
    return itertools.chain.from_iterable(
        numpy.rollaxis(a, axis, a.ndim).reshape(-1, dim)
        for axis, dim in enumerate(a.shape))

这将按照您在帖子中提供的顺序生成子数组:

for x in iter1d(a):
    print x

打印

[ 0 12]
[ 1 13]
[ 2 14]
[ 3 15]
[ 4 16]
[ 5 17]
[ 6 18]
[ 7 19]
[ 8 20]
[ 9 21]
[10 22]
[11 23]
[0 4 8]
[1 5 9]
[ 2  6 10]
[ 3  7 11]
[12 16 20]
[13 17 21]
[14 18 22]
[15 19 23]
[0 1 2 3]
[4 5 6 7]
[ 8  9 10 11]
[12 13 14 15]
[16 17 18 19]
[20 21 22 23]

这里的技巧是迭代所有轴,并且每个轴将数组重新整形为二维数组,其中的行是所需的一维子数组。

答案 1 :(得分:5)

可能有一种更有效的方式,但这应该有用......

import itertools
import numpy as np

a = np.arange(24)
a = a.reshape(2,3,4)

colon = slice(None)
dimensions = [range(dim) + [colon] for dim in a.shape]

for dim in itertools.product(*dimensions):
    if dim.count(colon) == 1:
        print a[dim]

这会产生(我只留下一些微不足道的代码来打印这个的左侧......):

a[0,0,:] -->  [0 1 2 3]
a[0,1,:] -->  [4 5 6 7]
a[0,2,:] -->  [ 8  9 10 11]
a[0,:,0] -->  [0 4 8]
a[0,:,1] -->  [1 5 9]
a[0,:,2] -->  [ 2  6 10]
a[0,:,3] -->  [ 3  7 11]
a[1,0,:] -->  [12 13 14 15]
a[1,1,:] -->  [16 17 18 19]
a[1,2,:] -->  [20 21 22 23]
a[1,:,0] -->  [12 16 20]
a[1,:,1] -->  [13 17 21]
a[1,:,2] -->  [14 18 22]
a[1,:,3] -->  [15 19 23]
a[:,0,0] -->  [ 0 12]
a[:,0,1] -->  [ 1 13]
a[:,0,2] -->  [ 2 14]
a[:,0,3] -->  [ 3 15]
a[:,1,0] -->  [ 4 16]
a[:,1,1] -->  [ 5 17]
a[:,1,2] -->  [ 6 18]
a[:,1,3] -->  [ 7 19]
a[:,2,0] -->  [ 8 20]
a[:,2,1] -->  [ 9 21]
a[:,2,2] -->  [10 22]
a[:,2,3] -->  [11 23]

这里的关键是使用(例如)a索引a[0,0,:]等同于使用a[(0,0,slice(None))]索引a。 (这只是通用的python切片,没有特定的numpy特定。为了向自己证明,你可以编写一个只有__getitem__的虚拟类,并打印在索引虚拟类的实例时传入的内容。)。

所以,我们想要的是0到nx,0到ny,0到nz等的每种可能组合以及每个轴的None

但是,我们需要1D数组,因此我们需要过滤掉多于或少于一个None的任何内容(即我们不希望a[:,:,:]a[0,:,:],{{1等等)。

希望无论如何都有道理......

编辑:我假设确切的顺序并不重要......如果您需要在问题中列出的确切顺序,则需要修改此...

答案 2 :(得分:0)

您的朋友是slice()个对象,numpy的ndarray.__getitem__()方法,可能是itertools.chain.from_iterable