Python / Numba - 自定义类对象作为输入类型

时间:2018-05-22 12:22:02

标签: python python-3.x numpy numba

我从numba开始,我的第一个目标是尝试使用嵌套循环加速一个不那么复杂的函数。

鉴于以下课程:

class TestA:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def get_mult(self):
        return self.a * self.b

和包含类numpy ndarray个对象的TestA。尺寸(N,)其中N的长度通常约为300万。

现在给出以下功能:

def test_no_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in range(container_length):
        for j in range(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

我试图玩numba以使其与上述功能一起使用但是我似乎无法使用nopython=True标志,如果它是设置为false,然后运行时高于no-jit函数。

以下是我尝试jit函数的最新尝试(也使用nb.prange):

@nb.jit(nopython=False, parallel=True)
def test_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in nb.prange(container_length):
        for j in nb.prange(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

我试图搜索一下,但我似乎无法找到如何在签名中定义自定义类的教程,以及如何加速该类函数并使其运行在GPU和可能(有关此事的任何信息将被高度赞赏)使其与cuda库一起运行 - 这些库已安装并可供使用(之前与tensorflow一起使用)

1 个答案:

答案 0 :(得分:1)

numba文档提供了一个创建自定义类型的示例,即使在nopython模式下也是如此:https://numba.pydata.org/numba-doc/latest/extending/interval-example.html

在您的情况下,除非这是您实际想要的精简版本,否则似乎最简单的方法是重用现有类型。此外,长度为3M的对象阵列的构建速度将很慢,并且会产生碎片化的内存(因为对象没有存储在连续的块中)。

如何使用记录数组解决问题的示例:

x_dt = np.dtype([('a', np.float64),
                 ('b', np.float64)])
n = 30000
buf = np.arange(n*2).reshape((n, 2)).astype(np.float64)
vec3 = np.recarray(n, dtype=x_dt, buf=buf) 

@numba.njit
def mult(a):
    return a.a * a.b

@numba.jit(nopython=True, parallel=True)
def sum_of_prod(vector):
    sum = 0
    vector_len = len(vector)
    for i in numba.prange(vector_len):
        for j in numba.prange(i + 1, vector_len):
            sum += mult(vector[i]) + mult(vector[j])
    return sum

sum_of_prod(vec3)

FWIW,我不是专家。我在搜索如何在numba中实现非数字内容的自定义类型时发现了这个问题。在您的情况下,由于这是高度数字化的,因此我认为自定义类型可能会过大。