程序化嵌套numba.cuda函数调用

时间:2018-10-14 22:06:07

标签: python cuda numba

Numba和CUDA菜鸟在这里。我希望能够有一个numba.cuda函数以编程方式从设备中调用另一个函数,而不必将任何数据传递回主机。例如,给定设置

from numba import cuda

@cuda.jit('int32(int32)', device=True)
def a(x):
    return x+1

@cuda.jit('int32(int32)', device=True)
def b(x):
    return 2*x

我希望能够定义一个合成内核函数,例如

@cuda.jit('void(int32, __device__, int32)')
def b_comp(x, inner, result):
    y = inner(x)
    result = b(y)

并成功获得

b_comp(1, a, result)
assert result == 4

理想情况下,我希望b_comp在编译后接受各种函数参数[例如在上述调用之后,仍然可以接受b_comp(1, b, result)],但是对于函数参数在编译时固定的解决方案仍然适用于我。

从我阅读的内容来看,CUDA似乎支持传递函数指针。 This post暗示numba.cuda没有这种支持,但该帖子并不令人信服,并且成立已有1年。 supported Python in numba.cuda的页面未提及函数指针支持。但是它链接到supported Python in numba页面,这清楚表明numba.jit() 确实支持作为参数,尽管它们在编译时就固定了。如果numba.cuda.jit()像我上面说的那样做,那就可以了。在这种情况下,当为comp指定签名时,应该如何声明变量类型?还是可以使用numba.cuda.autojit()

如果numba不支持任何此类直接方法,那么元编程是否是一个合理的选择?例如。一旦知道inner函数,我的脚本就可以创建一个包含组成这些特定函数的python函数的新脚本,然后应用numba.cuda.jit(),然后导入结果。似乎令人费解,但这是我能想到的唯一其他基于numba的选项。

如果numba根本无法解决问题,或者至少没有严重的麻烦,我会很高兴给出一些细节的答案,再加上“切换到PyCuda”这样的提示。

1 个答案:

答案 0 :(得分:2)

这对我有用:

  1. 最初不使用cuda.jit装饰我的函数,以便它们仍然具有__name__属性
  2. 获取__name__属性
  3. 现在通过直接调用装饰器将cuda.jit应用于我的函数
  4. 为字符串中的composition函数创建python,并将其传递给exec

确切的代码:

from numba import cuda
import numpy as np


def a(x):
    return x+1

def b(x):
    return 2*x


# Here, pretend we've been passed the inner function and the outer function as arguments
inner_fun = a
outer_fun = b

# And pretend we have noooooo idea what functions these guys actually point to
inner_name = inner_fun.__name__
outer_name = outer_fun.__name__

# Now manually apply the decorator
a = cuda.jit('int32(int32)', device=True)(a)
b = cuda.jit('int32(int32)', device=True)(b)

# Now construct the definition string for the composition function, and exec it.
exec_string = '@cuda.jit(\'void(int32, int32[:])\')\n' \
              'def custom_comp(x, out_array):\n' \
              '    out_array[0]=' + outer_name + '(' + inner_name + '(x))\n'

exec(exec_string)

out_array = np.array([-1])
custom_comp(1, out_array)
print(out_array)

按预期,输出为

[4]