numba中的性能嵌套循环

时间:2016-12-09 01:11:00

标签: python numpy numba

出于性能原因,我已经开始使用除NumPy之外的Numba。我的Numba算法正在运行,但我觉得它应该更快。有一点让它变慢。以下是代码段:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                    if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
                    numpy.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1

在我看来,if命令正在减慢速度。有没有更好的办法? (我在此尝试实现的内容与之前发布的问题相关:Count possibilites for single crossoversws是一个大小为(gn, l)的NumPy数组,其中包含01 }的

2 个答案:

答案 0 :(得分:2)

鉴于想要确保所有项目相等的逻辑,您可以利用如果任何不相等的事实,您可以短路(即停止比较)计算。我稍微修改了你的原始函数,以便(1)你不重复相同的比较两次,并且(2)对所有嵌套循环求和,所以有一个可以比较的回报:

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1

    return ysum


@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):

                    incr_y = True
                    for j in range(i):
                        if ws[x1,j] != ws[x2,j]:
                            incr_y = False
                            break

                    if incr_y is True:
                        for j in range(i,l):
                            if ws[x1,j] != ws[x3,j]:
                                incr_y = False
                                break
                    if incr_y is True:
                        y += 1
                        ysum += 1
    return ysum

我不知道完整的功能是什么样的,但希望这可以帮助您开始正确的道路。

现在有一些时间:

l = 7
a = 2
gn = a**l
ws = np.random.randint(0,2,size=(gn,l))
In [23]:

%timeit rfunc1(ws, a , l)
1 loop, best of 3: 2.11 s per loop


%timeit rfunc2(ws, a , l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a , l)
Out[27]: 131919

In [30]: rfunc2(ws, a , l)
Out[30]: 131919

这可以让你加速50倍。

答案 1 :(得分:2)

而不仅仅是感受"您的瓶颈在哪里,为什么不配置文件您的代码并找到完全在哪里?

分析的第一个目标是测试代表性系统以识别什么是慢速(或使用太多RAM,或导致过多的磁盘I / O或网络I / O)。

分析通常会增加开销(典型的10x到100x减速),并且您仍然希望代码的使用方式与实际情况相似。提取测试用例并隔离您需要测试的系统部分。优选地,它已经被编写为已经在其自己的模块集中。

基本技巧包括IPython中的%timeit魔法,time.time(),timing decorator(参见下面的示例)。您可以使用这些技术来理解语句和函数的行为。

然后你有cProfile这将为你提供问题的高级视图,这样你就可以将注意力转移到关键功能上。

接下来,查看line_profiler,,它将逐行分析您选择的功能。结果将包括每行调用的次数以及每行所花费的时间百分比。这正是您需要了解的信息,以了解运行缓慢的原因以及原因。

perf stat可帮助您了解最终在CPU上执行的指令数以及CPU缓存的使用效率。这允许对矩阵运算进行高级调整。

heapy可以跟踪Python内存中的所有对象。这对于追捕奇怪的内存泄漏非常有用。如果您正在使用长时间运行的系统, 然后dowser会让您感兴趣:它允许您通过Web浏览器界面在长时间运行的过程中内省活动对象。

为了帮助您了解RAM使用率高的原因,请查看memory_profiler.这对于跟踪标记图表上的RAM使用情况特别有用,因此您可以向同事(或您自己)解释为什么某些功能会使用RAM超出预期。

示例:定义装饰器以自动执行定时测量

from functools import wraps

def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        t1 = time.time()
        result = fn(*args, **kwargs)
        t2 = time.time()
        print ("@timefn:" + fn.func_name + " took " + str(t2 - t1) + " seconds")
        return result
    return measure_time

@timefn
def your_func(var1, var2):
    ...

有关更多信息,我建议阅读上述内容的High performance Python(Micha Gorelick; Ian Ozsvald)。