必须有更多的pythonic方式:
r = np.arange(100)
results = []
for i in r:
for j in r:
for k in r:
for l in r:
#Here f is some predefined function
if f(i,j,k,l) < 5.0:
results.append(f(i,j,k,l))
我确信使用数组可以以某种方式简化这一点,我只是不确定如何。谢谢!
答案 0 :(得分:5)
使用itertools
笛卡尔积:
import itertools
r = np.arange(100)
results = []
for (i,j,k,l) in itertools.product(r,repeat=4):
if f(i,j,k,l) < 5.0:
results.append(f(i,j,k,l))
或者更紧凑的方式,使用列表理解:
[ f(i,j,k,l) for (i,j,k,l) in itertools.product(r,repeat=4) if f(i,j,k,l) < 5.0 ]
答案 1 :(得分:0)
使用NumPy的fromfunction
和boolean indexing可以避免for循环和if语句。提议的方法包含在comb_np(n)
中,而@Ohad Eytan提出的基于itertools
的解决方案包含在comb_it(n)
中。为方便起见,每个for循环的迭代次数(在您的示例中为100
)作为参数传递给两个函数(n
)。为了比较分析这两种方法,我使用了一个简单的政治函数f(x, y, z, t)
。
from numpy import fromfunction
from itertools import product
from numpy import arange
def f(x, y, z, t):
return x + 2*y + 3*z + 4*t
def comb_np(n):
arr = fromfunction(f, (n,)*4)
return arr[arr < 5.0]
def comb_it(n):
return [f(i,j,k,l) for (i,j,k,l) in product(arange(n),repeat=4) if f(i,j,k,l) < 5.0]
示例运行:
In [1302]: comb_np(10)
Out[1302]: array([ 0., 4., 3., 2., 4., 1., 4., 3., 2., 4., 3., 4.])
In [1303]: comb_it(10)
Out[1303]: [0, 4, 3, 2, 4, 1, 4, 3, 2, 4, 3, 4]
两种方法都产生相同的结果。到现在为止还挺好。但现在让我们评估效率方面是否存在差异:
In [1304]: import timeit
In [1305]: timeit.timeit("comb_np(10)", setup="from numpy import fromfunction;from __main__ import comb_np, f", number=1)
Out[1305]: 0.0008685288485139608
In [1306]: timeit.timeit("comb_it(10)", setup="from itertools import product;from numpy import arange;from __main__ import comb_it, f", number=1)
Out[1306]: 0.05153228418203071
In [1307]: timeit.timeit("comb_np(100)", setup="from numpy import fromfunction;from __main__ import comb_np, f", number=1)
Out[1307]: 3.4775129712652415
In [1308]: timeit.timeit("comb_it(100)", setup="from itertools import product;from numpy import arange;from __main__ import comb_it, f", number=1)
Out[1308]: 354.3811327822914
从上面的结果可以清楚地看出,在这个特殊问题中,NumPy的矢量化代码比迭代器大约高出两个数量级。
有趣的是,我发现只需用内置函数arange
替换NumPy的range
,comb_it
的性能就会大大提高:
def comb_it2(n):
return [f(i,j,k,l) for (i,j,k,l) in product(range(n),repeat=4) if f(i,j,k,l) < 5.0]
结果:
In [1381]: comb_it2(10)
Out[1381]: [0, 4, 3, 2, 4, 1, 4, 3, 2, 4, 3, 4]
In [1382]: timeit.timeit("comb_it2(10)", setup="from itertools import product;from __main__ import comb_it2, f", number=1)
Out[1382]: 0.009133451094385237
In [1383]: timeit.timeit("comb_it2(100)", setup="from itertools import product;from __main__ import comb_it2, f", number=1)
Out[1383]: 32.556062019226374