我认为我的代码无效,因为我有一个数组列表。是否有不同的方法可以将final_list
数组列表构造为矩阵,以便numba
接受它?
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
@nb.jit(nopython=True)
def logi(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
for i in range(N_SPLITS):
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
plt.savefig('bifig.pdf')
File "logi.py", line 18, in <module>
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
File "/usr/local/lib/python2.7/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
numba.errors.TypingError: Caused By:
Traceback (most recent call last):
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 235, in run
stage()
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 449, in stage_nopython_frontend
self.locals)
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 805, in type_inference_stage
infer.propagate()
File "/usr/local/lib/python2.7/site-packages/numba/typeinfer.py", line 767, in propagate
raise errors[0]
TypingError: Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
Failed at nopython (nopython frontend)
Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
答案 0 :(得分:2)
您的代码存在许多问题导致numba jit-compiler出现问题:
np.sort
无效,也无法在2D数组上使用它
(见:numpy supported
features)
x
从浮点数更改为数组。 Numba要求整个函数的类型一致性
下面是一个numba函数,它以nopython
模式编译并产生相同的结果。基本上我预先分配存储阵列,因为事先知道大小然后按列排序。不幸的是numba
并没有真正好的排序实现,所以你不会获得非常大的加速。您可以进行其他性能调整更改。另请注意,在绘图部分的每个循环中调用logi
然后拉出单个值是没有意义的。只需计算一次数组,然后选出所需的值。
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
%matplotlib inline
def logi_orig(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
@nb.jit(nopython=True)
def logi_nb(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
out = np.empty_like(final_list)
for n in range(r.shape[0]):
out[:,n] = np.sort(final_list[:,n])
return out
def logi(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
y_orig = logi_orig(0.5, r)
y = logi(0.5, r)
y_nb = logi_nb(0.5, r)
print np.allclose(y, y_orig)
print np.allclose(y_nb, y_orig)
for i in range(N_SPLITS):
plt.plot(r, y[i], c='k', lw=0.1)
OSX(2014 MBP)与Numba v0.34.0的时间安排:
%timeit logi_orig(0.5, r)
%timeit logi(0.5, r)
%timeit logi_nb(0.5, r)
10 loops, best of 3: 171 ms per loop
10 loops, best of 3: 168 ms per loop
10 loops, best of 3: 77 ms per loop