如何编译具有可变输入类型的数字函数?

时间:2019-04-23 07:48:56

标签: python random signature optional-parameters numba

说我有一个可以接受intNone类型作为输入参数的函数

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(None)
    out = np.random.normal()
    return out

我希望函数简单地返回一个正态分布的随机数。如果我想要可重复的结果,则种子应该为int

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327

如果我需要随机数,则应将seed保留为None。但是,如果我不传递参数(因此种子默认为None)或显式传递seed=None,则numba会引发TypeError

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)

在这种情况下,如何编写函数,仍然声明签名并使用nopython模式?

我的numba版本是0.43.1

1 个答案:

答案 0 :(得分:2)

第一个问题是nopython模式下的numba仅接受(自版本0.43.1起)np.random.seed: with an integer argument only

因此,很遗憾,您无法通过None


第二个问题是(据我所知)没有一个“单一”签名告诉numba如何处理缺失值,但是您可以使用两个签名(是的,非常冗长):

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)), 
     nb.types.float64(nb.types.int64)], 
    **jitkw)
def get_random(seed=None):
    return np.random.normal()

仅简要说明一下签名的两个部分:

  • 如果省略了参数,nb.types.float64(nb.types.misc.Omitted(None))告诉numba使用None作为默认类型
  • nb.types.float64(nb.types.int64)是期​​望整数的签名。

我个人不会指定签名,而只是让numba找出来。显式签名在numba中很少值得使用,更多时候,不是,它们会导致代码变慢和灵活性降低。