我有一个很大的2D NumPy数组,比方说5M行和10列。我想根据使用Numba @jitclass
实现的一些有状态逻辑构建更多列。假设要创建50个这样的新列。我们的想法是在Numba @jit
函数中迭代10列的所有行,并且对于每一行,应用我的每个50“过滤器”以生成每个新的单元格。所以:
Source1..Source10 Derived1..Derived50
[array of 10 inputs] [array of 50 outputs]
... 5 million rows like this ...
问题是,我无法将“过滤器”的列表或元组传递给@jit(nopython=True)
函数,因为它们不是同质的:
@numba.jit(nopython=True)
def calc_derived(source, derived, filters):
for srcidx, src in enumerate(source):
for filtidx, filt in enumerate(filters): # doesn't work
derived[srcidx,filtidx] = filt.transform(src)
上述方法无效,因为filters
是一堆不同的类。据我所知,即使从一个共同的基类派生它们也不够好。
我有可能交换循环的顺序,并将循环覆盖在@jit
函数之外的50个过滤器上,但这意味着整个源数据集将被加载50次而不是曾经,这非常浪费。
您是否有技术可以解决Numba的“同质列表”要求?
答案 0 :(得分:2)
您最初询问是否使用循环遍历行的单个函数执行此操作,并将过滤器列表应用于每一行。这种方法的一个挑战是numba需要知道或能够推断每个函数的输入/输出类型。我不知道在这种情况下满足numba要求的方法(这并不是说不存在)。如果有办法做到这一点,它可能是一个更好的解决方案(而且我想知道它是什么)。
另一种方法是将遍历行的代码移动到过滤器本身。因为过滤器是numba函数,所以这应该保持速度。应用过滤器的功能将更长时间使用numba;它只是循环遍历过滤器列表。但是,由于过滤器的数量相对于数据矩阵的大小而言很小,因此希望这不会对影响速度造成太大的影响。因为此功能不再使用numba,所以异构列表'问题将不再是一个问题。
这种方法在我测试时起作用(nopython模式很好)。在测试用例中,作为numba函数实现的过滤器比作为类方法实现的过滤器快10-18倍(即使类被实现为numba jitclasses;不确定那里发生了什么)。为了获得一点模块性,可以将过滤器构造为闭包,以便可以使用不同的参数定义类似的过滤器。
例如,这里是计算权力总和的过滤器。给定矩阵x
,过滤器在x列上运行,为每行提供输出。它返回一个向量v
,其中v[i] = sum(x[i, :] ** power)
# filter constructor
def sumpow(power):
@numba.jit(nopython=True)
def run_filter(x):
(nrows, ncols) = x.shape
result = np.zeros(nrows)
for i in range(nrows):
for j in range(ncols):
result[i] += x[i,j] ** power
return result
return run_filter
# define filters
sum1 = sumpow(1) # sum of elements
sum2 = sumpow(2) # sum of elements squared
# apply a single filter
v = sum2(x)
应用多个过滤器的功能如下所示。每个过滤器的输出都堆叠在输出列中。
def apply_filters(x, filters):
result = np.empty((x.shape[0], len(filters)))
for (i, f) in enumerate(filters):
result[:, i] = f(x)
return result
y = apply_filters(x, [sum1, sum2])
计时结果
sum2
过滤器,在列表中重复20次:[sum2, sum2, ...]
- Numba功能过滤器(如上图所示):2.25s
- Numba jitclass过滤器:28.3s
- Pure NumPy(使用矢量化操作,无循环):8.64s
我想Numba可能会因为更复杂的过滤器而获得NumPy。
答案 1 :(得分:0)
要获得同类列表,您可以构建所有过滤器的transform
函数列表。在这种情况下,所有列表元素都将具有类型method
。
# filters = list of filters
transforms = [x.transform for x in filters]
然后将transforms
传递给calc_derived()
而不是filters
。
编辑: 在我的系统上,看起来像numba会接受这个,但仅限于nopython = False