基本上我想映射多维numpy数组的每个值。输出应具有与输入相同的形状。
这就是我这样做的方式:
def f(x):
return x*x
input = np.arange(3*4*5).reshape(3,4,5)
output = np.array(list(map(f, input)))
print(output)
它有效,但感觉有点过于复杂(np.array
,list
,map
)。有更优雅的解决方案吗?
答案 0 :(得分:2)
只需在数组上调用您的函数:
f(input)
另外,最好不要为变量使用名称input
,因为它是内置的:
import numpy as np
def f(x):
return x*x
arr = np.arange(3*4*5).reshape(3,4,5)
print(np.alltrue(f(arr) == np.array(list(map(f, input)))))
输出:
True
如果功能更复杂:
def f(x):
return x+1 if x%2 else 2*x
使用vectorize
:
np.vectorize(f)(arr)
更好的是,总是尝试使用矢量化的NumPy函数,例如np.where
:
>>> np.alltrue(np.vectorize(f)(arr) == np.where(arr % 2, arr + 1, arr * 2))
True
本机NumPy版本要快得多:
%%timeit
np.vectorize(f)(arr)
34 µs ± 996 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
np.where(arr % 2, arr + 1, arr * 2)
5.16 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
对于较大的数组来说,这更为明显:
big_arr = np.arange(30 * 40 * 50).reshape(30, 40, 50)
%%timeit
np.vectorize(f)(big_arr)
15.5 ms ± 318 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
np.where(big_arr % 2, big_arr + 1, big_arr * 2)
797 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)