将自定义函数应用于多维numpy数组,保持相同的形状

时间:2018-01-12 00:29:18

标签: python numpy

基本上我想映射多维numpy数组的每个值。输出应具有与输入相同的形状。

这就是我这样做的方式:

def f(x):
    return x*x

input = np.arange(3*4*5).reshape(3,4,5)
output = np.array(list(map(f, input)))
print(output)

它有效,但感觉有点过于复杂(np.arraylistmap)。有更优雅的解决方案吗?

1 个答案:

答案 0 :(得分:2)

只需在数组上调用您的函数:

f(input)

另外,最好不要为变量使用名称input,因为它是内置的:

import numpy as np

def f(x):
    return x*x

arr = np.arange(3*4*5).reshape(3,4,5)
print(np.alltrue(f(arr) == np.array(list(map(f, input)))))

输出:

True

如果功能更复杂:

def f(x):
    return x+1 if x%2 else 2*x

使用vectorize

np.vectorize(f)(arr)

更好的是,总是尝试使用矢量化的NumPy函数,例如np.where

>>> np.alltrue(np.vectorize(f)(arr) == np.where(arr % 2, arr + 1, arr * 2))
True

本机NumPy版本要快得多:

%%timeit
np.vectorize(f)(arr)
34 µs ± 996 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


%%timeit
np.where(arr % 2, arr + 1, arr * 2)
5.16 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对于较大的数组来说,这更为明显:

big_arr = np.arange(30 * 40 * 50).reshape(30, 40, 50)

%%timeit
np.vectorize(f)(big_arr)
15.5 ms ± 318 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
np.where(big_arr % 2, big_arr + 1, big_arr * 2)
797 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)