I want to use @jit or @autojit to speed up my Python code, as explained here: http://nbviewer.ipython.org/gist/harrism/f5707335f40af9463c43
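However, the examples on that page are for plain Python functions, whereas my function is a method inside a class. Based on some more searching, it seems that to make this work with class methods I have to provide an explicit signature.
I haven't used signatures before, but I now understand how to write them for functions with simple arguments. I'm confused about how to write them for more complex arguments such as 2D arrays.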
Below is the function I need an explicit signature for. I'm really not sure what to write after @void...
""" Function: train
Input parameters:
#X = shape: [n_samples, n_features]
#y = classes corresponding to X , y's shape: [n_samples]
#H = int, number of boosting rounds
Returns: None
Trains the model based on the training data and true classes
"""
#@autojit
#@void
def train(self, X, y, H):
    # function code below
    # do lots of stuff...
Given my parameter types, I tried this:
@void(float_[:,:],int_[:],int_)
but got the following error:
Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_carclassify.py", line 18, in <module>
    import gentleboost_c_class as gbc
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 20, in <module>
    @jit
  File "C:\Users\app\Anaconda\lib\site-packages\numba\decorators.py", line 272, in jit
    return jit_extension_class(cls, kws, env)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\entrypoints.py", line 20, in jit_extension_class
    return jitclass.create_extension(env, py_class, translator_kwargs)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\jitclass.py", line 98, in create_extension
    ext_type = typesystem.jit_exttype(py_class)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\typesystem\types.py", line 55, in __call__
    return type.__call__(self, *args)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\types\extensiontype.py", line 37, in __init__
    assert isinstance(py_class, type), ("Must be a new-style class "
AssertionError: Must be a new-style class (inherit from 'object')
I changed the beginning of my class to add (object), so it now looks like this:
import numba
from numba import jit, autojit, int_, void, float_

@jit
class GentleBoostC(object):
    def __init__(self):
        # init function
        pass

    @void(float_[:,:],int_[:],int_)
    def train(self, X, y, H): # this is the function I want to speed up
        # do stuff
        pass
But now I get this error:
C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\validators.py:74: UserWarning: Constructor for class 'GentleBoostC' has no signature, assuming arguments have type 'object'
  ext_type.py_class.__name__)
Traceback (most recent call last):
  File "C:\Users\app\Documents\Python Scripts\gbc_carclassify.py", line 18, in <module>
    import gentleboost_c_class as gbc
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 21, in <module>
    class GentleBoostC(object):
  File "C:\Users\app\Anaconda\lib\site-packages\numba\decorators.py", line 272, in jit
    return jit_extension_class(cls, kws, env)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\entrypoints.py", line 20, in jit_extension_class
    return jitclass.create_extension(env, py_class, translator_kwargs)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\jitclass.py", line 110, in create_extension
    extension_compiler.infer()
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 112, in infer
    self.type_infer_methods()
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 145, in type_infer_methods
    self.type_infer_method(method)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\exttypes\compileclass.py", line 121, in type_infer_method
    **self.flags)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\pipeline.py", line 133, in compile2
    func_ast = functions._get_ast(func)
  File "C:\Users\app\Anaconda\lib\site-packages\numba\functions.py", line 89, in _get_ast
    ast.PyCF_ONLY_AST | flags, True)
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class.py", line 1
    def train(self, X, y, H):
    ^
IndentationError: unexpected indent
I don't think I have an indentation error... I had no problems with this exact same code before I added (object) to the class.
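The last traceback ends inside numba's `_get_ast`, which passes the method's source to `compile(..., ast.PyCF_ONLY_AST | flags, True)` and reports `def train(...)` at line 1 with an unexpected indent. Plain Python raises exactly that error when handed a method's source text with its class-level indentation still attached. The sketch below reproduces this with only the standard library. It assumes, without confirmation from the traceback, that this numba version strips common indentation in a `textwrap.dedent`-like way; if so, a single non-blank line inside the method that starts at column 0 (for example a flush-left docstring or comment line) would leave the whole block indented. The method strings are hypothetical stand-ins for the real `train` source.

```python
import ast
import textwrap


def parses(source):
    """Return True if `source` compiles to an AST, False on IndentationError."""
    try:
        # Same flag that appears in the traceback's call to compile().
        compile(source, "<string>", "exec", ast.PyCF_ONLY_AST)
        return True
    except IndentationError:
        return False


# Source text of a method as it sits inside a class body: indented by 4 spaces.
method_src = (
    "    def train(self, X, y, H):\n"
    "        return None\n"
)

# Hypothetical variant: one docstring line starts at column 0.
flush_left_src = (
    "    def train(self, X, y, H):\n"
    '        """Trains the model.\n'
    "Returns: None\n"
    '        """\n'
    "        return None\n"
)

print(parses(method_src))                       # False: unexpected indent
print(parses(textwrap.dedent(method_src)))      # True: common indentation removed
print(parses(textwrap.dedent(flush_left_src)))  # False: no common margin to remove
```

`textwrap.dedent` only removes whitespace shared by every non-blank line, so one flush-left line is enough to make it a no-op.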
Answer 0 (score: 1)
You can use slice syntax on the data types to denote arrays, so your example could look like this:
from numba import void, int_, float_, jit
...
@jit
class YourClass(object):
    ...
    @void(float_[:, :], int_[:], int_)
    def train(self, X, y, H):
        # X is typed as a 2D float array and y as a 1D int array.
        pass