Numba优化功能无法提高速度

时间:2018-08-06 20:28:11

标签: python performance numpy optimization numba

我写了一些代码,计算出不同卡车对300英尺长的桥梁施加的弯矩。卡车数据包含在两个列表中:ax_listsp_list,分别是车轴重量和车轴间距。

代码没有太多内容,但是,这需要针对数百万种不同类型的卡车进行重复,并且我正在尝试优化我的代码,这在考虑实际数据大小集时会花费很长时间。 >

我尝试使用Numba来查看是否可以提高速度,但是无论是否为每个函数添加Numba @jit装饰器,它都没有改变执行时间。我在这里做错了什么?任何帮助都将受到欢迎!我还包括以下代码,可为1000条记录生成代表性的伪数据:

import random
from numba import jit
import numpy as np
from __future__ import division

#Generate Random Data Set

ax_list=[]
sp_list=[]

for i in xrange(1000):
    n = random.randint(3,10)
    ax = []
    sp = [0]
    for i in xrange(n):
        a = round(random.uniform(8,32),1)
        ax.append(a)
    for i in xrange(n-1):
        s = round(random.uniform(4,30), 1)
        sp.append(s)
    ax_list.append(ax)
    sp_list.append(sp)

#Input Parameters
L=300
step_size=4
cstep_size=4
moment_list=[]

@jit
#Simple moment function
def Moment(x):
    if x<L/2.0:
        return 0.5*x
    else:
        return 0.5*(L-x)

#Attempt to vectorize the Moment function, hoping for speed gains
vectMoment = np.vectorize(Moment,otypes=[np.float],cache=False)

@jit
#Truck movement function that uses the vectorized Moment function above
def SimpleSpanMoment(axles, spacings, step_size):
    travel = L + sum(spacings)
    spacings=list(spacings)
    maxmoment = 0
    axle_coords =(0-np.cumsum(spacings))
    while np.min(axle_coords) < L:
        axle_coords = axle_coords + step_size
        moment_inf = np.where((axle_coords >= 0) & (axle_coords <=L), vectMoment(axle_coords), 0)
        moment = sum(moment_inf * axles)
        if maxmoment < moment:
            maxmoment = moment
    return maxmoment

然后将循环运行1000次:

%%timeit
for i in xrange(len(ax_list)):
    moment_list.append(np.around(SimpleSpanMoment(ax_list[i], sp_list[i], step_size),1))

产量:

1 loop, best of 3: 2 s per loop

我还尝试在jit装饰器中声明类型,但结果仍然没有变化。

@jit('f8(f8)')@jit('f8(f8[:],f8[:],f8)')分别用于这两个功能。

1 个答案:

答案 0 :(得分:4)

基本问题是,当您使用nb.jit并且遇到无法编译为本机代码的问题时,它会使用object mode来代替,这可能会很慢。如果您将nopython=True指定为大多数numba装饰器/函数的参数,则可以看到此信息。如果numba无法显式键入变量或不知道如何翻译函数,则将收到错误消息。我相信这是一个与原始功能产生相同结果的版本。在我的机器上,您的代码大约需要2.7秒才能运行。完全在nopython模式下运行的经过优化的以下版本大约需要50毫秒(约50倍的加速):

@nb.vectorize(nopython=True)
#Simple moment function
def vectMoment2(x):
    if x<L/2.0:
        return 0.5*x
    else:
        return 0.5*(L-x)

@nb.jit(nopython=True)
#Truck movement function that uses the vectorized Moment function above
def SimpleSpanMoment2(axles, spacings, step_size):
    travel = L + np.sum(spacings)
    maxmoment = 0
    axle_coords = -np.cumsum(spacings)

    moment_inf = np.empty_like(axles)
    while np.min(axle_coords) < L:
        axle_coords = axle_coords + step_size
        y = vectMoment2(axle_coords)

        for k in range(y.shape[0]):
            if axle_coords[k] >=0 and axle_coords[k] <= L:
                moment_inf[k] = y[k]
            else:
                moment_inf[k] = 0.0

        moment = np.sum(moment_inf * axles)
        if maxmoment < moment:
            maxmoment = moment
    return maxmoment

,然后通过以下方式计时:

%%timeit
for i in xrange(len(ax_list)):
    moment_list2.append(np.around(SimpleSpanMoment2(np.array(ax_list[i]), np.array(sp_list[i]), step_size),1))

看看文档,了解Numba在nopython模式下支持什么:

请注意,您可以在np.where模式下在numba函数内使用nopython,但第3个参数必须是数组(例如np.zeros_like(moment_inf))而不是整数。我发现它比我上面显式循环遍历数组的函数的编码方式慢大约2倍。