合并i ^ {th}轴之前和之后的轴

时间:2017-12-02 02:43:58

标签: python numpy

我想为arr (before, at, after)的{​​{1}}个axis重塑一个numpy数组arr。如何更快地完成这项工作?

轴已经标准化:0 <= axis < arr.ndim

程序:

import numpy as np
def f(arr, axis):
    shape = arr.shape
    before = int(np.product(shape[:axis]))
    at = shape[axis]
    return arr.reshape(before, at, -1)

测试:

a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(f(a, 2).shape)

结果:

(6, 4, 5)

2 个答案:

答案 0 :(得分:2)

shape是一个元组,所需的结果也是一个元组。转换为数组或从数组转换为使用np.prod或其他一些数组函数需要时间。因此,如果我们可以使用普通的Python代码执行相同操作,我们可以节省时间

例如shape

In [309]: shape
Out[309]: (2, 3, 4, 5)
In [310]: np.prod(shape)
Out[310]: 120
In [311]: functools.reduce(operator.mul,shape)
Out[311]: 120

In [312]: timeit np.prod(shape)
13.6 µs ± 30.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [313]: timeit functools.reduce(operator.mul,shape)
647 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

python版本明显更快。我必须导入functoolsoperator才能获得相当于sum(Python3)的乘法。

或者获得新的形状元组:

In [314]: axis=2
In [315]: (functools.reduce(operator.mul,shape[:axis]),shape[axis],-1)
Out[315]: (6, 4, -1)
In [316]: timeit (functools.reduce(operator.mul,shape[:axis]),shape[axis],-1)
739 ns ± 30.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

比较建议的reduceat

In [318]: tuple(np.multiply.reduceat(shape, (0, axis, axis+1)))
Out[318]: (6, 4, 5)
In [319]: timeit tuple(np.multiply.reduceat(shape, (0, axis, axis+1)))
11.3 µs ± 21.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

答案 1 :(得分:1)

如果您的轴位于中间位置,则可以使用np.multiply.reduceat

 shape = (2, 3, 4, 5, 6)
 axis = 2
 np.multiply.reduceat(shape, (0, axis, axis+1))
 # array([ 6,  4, 30])
 axis = 3
 np.multiply.reduceat(shape, (0, axis, axis+1))
 # array([24,  5,  6])

如果您想要第0轴或最后一轴,则必须使用特殊情况。