在Python中迭代3d数组中包含的2d数组

时间:2014-12-23 20:49:50

标签: python numpy multidimensional-array iteration vectorization

我已经看到很多问题要求以更快的方式迭代2d数组的每个元素,但我还没有找到一个好的方法迭代3d数组以便在每个2d数组上应用函数。例如:

from scipy.fftpack import dct
X = np.arange(10000, dtype=np.float32).reshape(-1,4,4)
np.array(map(dct, X))

这里,我将浏览维度为(625,4,4)的3d数组中包含的每个2d数组,并对每个4X4数组应用DCT(离散余弦变换)。我想知道是否有更合适的方法来实现这一目标。

由于

1 个答案:

答案 0 :(得分:6)

Numpy功能:

在这种情况下,由于dctnumpy函数,因此它具有内置的功能,可以将其应用于特定的轴。几乎所有的numpy函数都在完整的数组上运行,或者可以被告知在特定的轴(行或列)上运行。

所以只需利用axis函数的dct参数:

dct( X, axis=2)

您将获得相同的结果:

>>> ( dct(X, axis=2) == np.array(map(dct, X)) ).all()
True

map矩阵的情况下,使用(625,4,4)函数的速度也比其快35倍:

%timeit dct(X, axis=2)
1000 loops, best of 3: 157 µs per loop

%timeit np.array(map(dct, X))
100 loops, best of 3: 5.76 ms per loop    

常规Python函数:

在其他情况下,您可以使用np.vectorizenp.frompyfunc函数vectorize python函数。例如,如果您有一个执行标量操作的演示函数:

def foo(x): # gives an error if passed in an array
    return x**2

>>> X = np.arange(8, dtype=np.float32).reshape(-1,2,2)
>>> foo_arr = np.vectorize( foo)
>>> foo_arr(X)
array([[[  0.,   1.],
        [  4.,   9.]],

       [[ 16.,  25.],
        [ 36.,  49.]]])

讨论here也可能对您有所帮助。正如他们所说,矢量化你的非numpy函数实际上并没有让它更快。