我一直试图从numba文档中全天推断出类型的设置方式。我已经有了一些方法,但是现在我想创建一个返回一维数组和一个二维数组的函数,并采取一堆args,我很难得到任何进一步:
@jit
class name(object)
@double[:,:], double[:](double[:], double, double, int64)
def solve(self, u0, a, b, n):
self.t = linspace(a, b, n+1)
dt = abs((b-a)/float(n))
u = zeros(n+1, len([u0]))
u[0] = u0
u = advance(u, t, n, dt)
return u.transpose(), t.transpose()
以上引发了这些例外情况:
Traceback (most recent call last):
File "/home/marius/dev/python/inf1100/test_ODE.py", line 2, in <module>
from DE import *
File "/home/marius/dev/python/inf1100/DE.py", line 13
@double[:,:], double[:](double[:], double, double, int64)
^
SyntaxError: invalid syntax
如果你能告诉我出了什么问题会很好,但是如果你能推荐一个一劳永逸地严格解释这些语法的文件会更好。
感谢您的时间。
亲切的问候, 的Marius
答案 0 :(得分:2)
这是一个返回元组的方法的简单版本。这适用于我在OS X上使用Numba 0.11.1:
import numba
import numpy as np
@numba.jit
class name(object):
@numba.object_(numba.double[:], numba.double)
def solve(self, x, a):
y = np.empty(x.shape[0], dtype=np.float64)
z = np.empty(x.shape[0], dtype=np.float64)
for k in xrange(x.shape[0]):
y[k] = x[k] * a
z[k] = x[k] + a
return y, z
然后使用它:
C = name()
a, b = C.solve(np.arange(5, dtype=np.float64), 3.0)
a
和b
的位置:
In [24]: a
Out[24]:
array([ 0., 3., 6., 9., 12.])
In [22]: b
Out[22]:
array([ 3., 4., 5., 6., 7.])