使用Numba在每一行上应用多个函数

时间:2016-05-27 01:54:55

标签: python numpy optimization numba

我有一个很大的2D NumPy数组,比方说5M行和10列。我想根据使用Numba @jitclass实现的一些有状态逻辑构建更多列。假设要创建50个这样的新列。我们的想法是在Numba @jit函数中迭代10列的所有行,并且对于每一行,应用我的每个50“过滤器”以生成每个新的单元格。所以:

 Source1..Source10    Derived1..Derived50
[array of 10 inputs] [array of 50 outputs]
     ... 5 million rows like this ...

问题是,我无法将“过滤器”的列表或元组传递给@jit(nopython=True)函数,因为它们不是同质的:

@numba.jit(nopython=True)
def calc_derived(source, derived, filters):
    for srcidx, src in enumerate(source):
        for filtidx, filt in enumerate(filters): # doesn't work
            derived[srcidx,filtidx] = filt.transform(src)

上述方法无效,因为filters是一堆不同的类。据我所知,即使从一个共同的基类派生它们也不够好。

我有可能交换循环的顺序,并将循环覆盖在@jit函数之外的50个过滤器上,但这意味着整个源数据集将被加载50次而不是曾经,这非常浪费。

您是否有技术可以解决Numba的“同质列表”要求?

2 个答案:

答案 0 :(得分:2)

您最初询问是否使用循环遍历行的单个函数执行此操作,并将过滤器列表应用于每一行。这种方法的一个挑战是numba需要知道或能够推断每个函数的输入/输出类型。我不知道在这种情况下满足numba要求的方法(这并不是说不存在)。如果有办法做到这一点,它可能是一个更好的解决方案(而且我想知道它是什么)。

另一种方法是将遍历行的代码移动到过滤器本身。因为过滤器是numba函数,所以这应该保持速度。应用过滤器的功能将更长时间使用numba;它只是循环遍历过滤器列表。但是,由于过滤器的数量相对于数据矩阵的大小而言很小,因此希望这不会对影响速度造成太大的影响。因为此功能不再使用numba,所以异构列表'问题将不再是一个问题。

这种方法在我测试时起作用(nopython模式很好)。在测试用例中,作为numba函数实现的过滤器比作为类方法实现的过滤器快10-18倍(即使类被实现为numba jitclasses;不确定那里发生了什么)。为了获得一点模块性,可以将过滤器构造为闭包,以便可以使用不同的参数定义类似的过滤器。

例如,这里是计算权力总和的过滤器。给定矩阵x,过滤器在x列上运行,为每行提供输出。它返回一个向量v,其中v[i] = sum(x[i, :] ** power)

# filter constructor
def sumpow(power):

    @numba.jit(nopython=True)
    def run_filter(x):
        (nrows, ncols) = x.shape
        result = np.zeros(nrows)
        for i in range(nrows):
            for j in range(ncols):
                result[i] += x[i,j] ** power
        return result

    return run_filter

# define filters
sum1 = sumpow(1) # sum of elements
sum2 = sumpow(2) # sum of elements squared

# apply a single filter
v = sum2(x)

应用多个过滤器的功能如下所示。每个过滤器的输出都堆叠在输出列中。

def apply_filters(x, filters):

    result = np.empty((x.shape[0], len(filters)))

    for (i, f) in enumerate(filters):
        result[:, i] = f(x)

    return result


y = apply_filters(x, [sum1, sum2])

计时结果

  • 数据矩阵:从标准正态分布,float64,500万行×10列中抽取的随机条目。所有方法都使用相同的矩阵进行测试。
  • 过滤器:上面sum2过滤器,在列表中重复20次:[sum2, sum2, ...]
  • 使用IPython&times%timeit函数计时,最好3次
  • 所有方法的数字输出同意
  
      
  • Numba功能过滤器(如上图所示):2.25s
  •   
  • Numba jitclass过滤器:28.3s
  •   
  • Pure NumPy(使用矢量化操作,无循环):8.64s
  •   

我想Numba可能会因为更复杂的过滤器而获得NumPy。

答案 1 :(得分:0)

要获得同类列表,您可以构建所有过滤器的transform函数列表。在这种情况下,所有列表元素都将具有类型method

# filters = list of filters
transforms = [x.transform for x in filters]

然后将transforms传递给calc_derived()而不是filters

编辑: 在我的系统上,看起来像numba会接受这个,但仅限于nopython = False