我怎样才能将数组附加到numba中的列表?

时间:2017-07-12 00:56:15

标签: python numba

我认为我的代码无效,因为我有一个数组列表。是否有不同的方法可以将final_list数组列表构造为矩阵,以便numba接受它?

import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000

@nb.jit(nopython=True)
def logi(x0, r):
  x = x0
  for n in range(30000):
     x = x * r * (1-x)
  final_list = [x]
  for n in range(N_SPLITS):
     final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
  return np.sort(final_list, axis=0)

r = np.arange(2.4, 4., .001)
for i in range(N_SPLITS):
   plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
plt.savefig('bifig.pdf')




  File "logi.py", line 18, in <module>
    plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
  File "/usr/local/lib/python2.7/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
numba.errors.TypingError: Caused By:
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 235, in run
    stage()
  File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 449, in stage_nopython_frontend
    self.locals)
  File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 805, in type_inference_stage
    infer.propagate()
  File "/usr/local/lib/python2.7/site-packages/numba/typeinfer.py", line 767, in propagate
    raise errors[0]
TypingError: Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
 * parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)

Failed at nopython (nopython frontend)
Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
 * parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)

1 个答案:

答案 0 :(得分:2)

您的代码存在许多问题导致numba jit-compiler出现问题:

    带参数的
  • np.sort无效,也无法在2D数组上使用它 (见:numpy supported features

  • x从浮点数更改为数组。 Numba要求整个函数的类型一致性

下面是一个numba函数,它以nopython模式编译并产生相同的结果。基本上我预先分配存储阵列,因为事先知道大小然后按列排序。不幸的是numba并没有真正好的排序实现,所以你不会获得非常大的加速。您可以进行其他性能调整更改。另请注意,在绘图部分的每个循环中调用logi然后拉出单个值是没有意义的。只需计算一次数组,然后选出所需的值。

import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000

%matplotlib inline

def logi_orig(x0, r):
    x = x0
    for n in range(30000):
        x = x * r * (1-x)
    final_list = [x]
    for n in range(N_SPLITS):
        final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
    return np.sort(final_list, axis=0)

@nb.jit(nopython=True)
def logi_nb(x0, r):
    x = np.full_like(r, x0)
    for n in range(30000):
        x = x * r * (1-x)
    final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
    final_list[0,:] = x
    for n in range(1, N_SPLITS + 1):
        final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])

    out = np.empty_like(final_list)
    for n in range(r.shape[0]):
        out[:,n] = np.sort(final_list[:,n])

    return out

def logi(x0, r):
    x = np.full_like(r, x0)
    for n in range(30000):
        x = x * r * (1-x)
    final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
    final_list[0,:] = x
    for n in range(1, N_SPLITS + 1):
        final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])

    return np.sort(final_list, axis=0)

r = np.arange(2.4, 4., .001)

y_orig = logi_orig(0.5, r)
y = logi(0.5, r)
y_nb = logi_nb(0.5, r)

print np.allclose(y, y_orig)
print np.allclose(y_nb, y_orig)

for i in range(N_SPLITS):
    plt.plot(r, y[i], c='k', lw=0.1)

OSX(2014 MBP)与Numba v0.34.0的时间安排:

%timeit logi_orig(0.5, r)
%timeit logi(0.5, r)
%timeit logi_nb(0.5, r)

10 loops, best of 3: 171 ms per loop
10 loops, best of 3: 168 ms per loop
10 loops, best of 3: 77 ms per loop