Python - Sum 4D数组

时间:2014-07-19 14:15:21

标签: python arrays numpy

给定4D数组M: (m, n, r, r),我如何求和所有m * n内部矩阵(形状(r, r))以获得新的形状矩阵{{1} }?

例如,

(r * r)

我希望结果应该是

    M [[[[ 4,  1],
         [ 2,  1]],

        [[ 8,  2],
         [ 4,  2]]],

       [[[ 8,  2],
         [ 4,  2]],

        [[ 12, 3],
         [ 6,  3]]]]

3 个答案:

答案 0 :(得分:5)

您可以使用einsum

In [21]: np.einsum('ijkl->kl', M)
Out[21]: 
array([[32,  8],
       [16,  8]])

其他选项包括将前两个轴重塑为一个轴,然后调用sum

In [24]: M.reshape(-1, 2, 2).sum(axis=0)
Out[24]: 
array([[32,  8],
       [16,  8]])

或两次调用sum方法:

In [26]: M.sum(axis=0).sum(axis=0)
Out[26]: 
array([[32,  8],
       [16,  8]])

但使用np.einsum的速度更快:

In [22]: %timeit np.einsum('ijkl->kl', M)
100000 loops, best of 3: 2.42 µs per loop

In [25]: %timeit M.reshape(-1, 2, 2).sum(axis=0)
100000 loops, best of 3: 5.69 µs per loop

In [43]: %timeit np.sum(M, axis=(0,1))
100000 loops, best of 3: 6.08 µs per loop

In [33]: %timeit sum(sum(M))
100000 loops, best of 3: 8.18 µs per loop

In [27]: %timeit M.sum(axis=0).sum(axis=0)
100000 loops, best of 3: 9.83 µs per loop

警告:由于许多因素(操作系统,NumPy版本,NumPy库,硬件等),timeit基准测试可能会有很大差异。各种方法的相对性能有时也取决于M的大小。因此,在M更接近实际用例的情况下进行自己的基准测试是值得的。

例如,对于稍大一些的数组M,调用sum方法两次可能是最快的:

In [34]: M = np.random.random((100,100,2,2))

In [37]: %timeit M.sum(axis=0).sum(axis=0)
10000 loops, best of 3: 59.9 µs per loop

In [39]: %timeit np.einsum('ijkl->kl', M)
10000 loops, best of 3: 99 µs per loop

In [40]: %timeit np.sum(M, axis=(0,1))
10000 loops, best of 3: 182 µs per loop

In [36]: %timeit M.reshape(-1, 2, 2).sum(axis=0)
10000 loops, best of 3: 184 µs per loop

In [38]: %timeit sum(sum(M))
1000 loops, best of 3: 202 µs per loop

答案 1 :(得分:3)

到目前为止,最近的numpy(版本1.7或更新版本)中最简单的是:

np.sum(M, axis=(0, 1))

这不会构建一个中间数组,因为对np.sum的重复调用会。

答案 2 :(得分:1)

import numpy as np
l = np.array([[[[ 4,  1],
                [ 2,  1]],
               [[ 8,  2],
                [ 4,  2]]],
              [[[ 8,  2],
                [ 4,  2]],
               [[12,  3],
                [ 6,  3]]]])
sum(sum(l))

输出

array([[32,  8],
       [16,  8]])