Numba无法识别该字符串。如何更正以下代码?谢谢!
@nb.jit(nb.float64(nb.float64[:], nb.char[:]), nopython=True, cache=True)
def func(x, y='cont'):
"""
:param x: is np.array, x.shape=(n,)
:param y: is a string,
:return: a np.array of same shape as x
"""
return result
答案 0 :(得分:1)
以下适用于Numba 0.44的作品:
import numpy as np
import numba as nb
from numba import types
@nb.jit(nb.float64[:](nb.float64[:], types.unicode_type), nopython=True, cache=True)
def func(x, y='cont'):
"""
:param x: is np.array, x.shape=(n,)
:param y: is a string,
:return: a np.array of same shape as x
"""
print(y)
return x
但是,如果尝试在未指定值func
的情况下运行y
,则会出错,因为在签名中,您需要第二个参数。我试图弄清楚如何处理可选参数(在types.Omitted
处查看),但还不太清楚。我可能会考虑不指定签名,而让numba进行正确的类型推断:
@nb.jit(nopython=True, cache=True)
def func2(x, y='cont'):
"""
:param x: is np.array, x.shape=(n,)
:param y: is a string,
:return: a np.array of same shape as x
"""
print(y)
return x