如何使用numba加速快速排序?

时间:2015-03-22 22:01:09

标签: python numpy numba

我正在尝试使用Python中的numba实现快速排序算法。

它似乎比numpy sort函数慢很多。

我怎么能改善它?我的代码在这里:

import numba as nb

@nb.autojit
def quick_sort(list_):
    """
    Iterative version of quick sort
    """
    #temp_stack = []
    #temp_stack.append((left,right))

    max_depth = 1000

    left = 0
    right = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:
            #temp_stack.append((left,piv-1))

            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        #If items in the right of the pivot push them to the stack
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j

我的测试代码在这里:

    x = np.random.random_integers(0,1000,1000000)
    y = x.copy()

    quick_sort( y )

    z = np.sort(x)

    np.testing.assert_array_equal( z, y )

    y = x.copy()
    with Timer( 'nb' ):
        numba_fns.quick_sort( y )

    with Timer( 'np' ):
        x = np.sort(x) 

更新:

我重写了函数来强制代码的循环部分在nopython模式下运行。 while循环似乎没有导致nopython失败。但是,我没有获得任何性能提升:

@nb.autojit
def quick_sort2(list_):
    """
    Iterative version of quick sort
    """

    max_depth = 1000

    left        = 0
    right       = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    return _quick_sort2( list_, a_temp_stack, left, right )

@nb.autojit( nopython=True )
def _quick_sort2( list_, a_temp_stack, left, right ):

    i_stack_pos = 1
    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:            
            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j

3 个答案:

答案 0 :(得分:3)

一个可能有所帮助的小建议(但正如你在评论的评论中正确地告诉你的那样,你将很难击败纯粹的C实现):

你想确保大部分内容都是在&#34; nopython&#34;模式(@jit(nopython=True))。在功能开始之前添加它,看看它在哪里断开。同时在您的功能上调用inspect_types(),看看它是否能够正确识别它们。

你的代码中有一点可能强迫它进入对象模式(与nopython模式相反)是一个numpy数组的分配。虽然numba可以在nopython模式下单独编译循环,但我不知道它是否可以为while循环执行此操作。致电inspect_types会告诉您。

我通常用于创建numpy数组的workround,同时确保其余部分处于nopython模式,这是创建一个包装函数。

@nb.jit(nopython=True) # make sure it can be done in nopython mode
def _quick_sort_impl(list_,output_array):
   ...most of your code goes here...

@nb.jit
def quick_sort(list_):
   # this code won't compile in nopython mode, but it's
   # short and isolated
   max_depth = 1000
   a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
   _quick_sort_impl(list_,a_temp_stack)

答案 1 :(得分:3)

一般情况下,如果您不强制使用nopython模式,则很有可能无法提升效果。引自the docs about nopython mode

  

[nopython]模式生成最高性能代码,但要求可以推断出函数中所有值的本机类型,并且不分配新对象

因此,您的np.ndarray调用会触发对象模式,从而减慢代码速度。尝试从函数外部分配工作数组,如:

def quick_sort(list_):

    max_depth = 1000
    temp_stack_ = np.array( ( max_depth, 2), dtype=np.int32 )

    _quick_sort(list_, temp_stack_)

...

@numba.jit(nopython=True)
def _quick_sort(list_, temp_stack_):
    ...

答案 2 :(得分:0)

对于它的价值,numba已经实现了通用sorted函数和numpy-array .sort()方法(我认为)版本0.22。万岁!

http://numba.pydata.org/numba-doc/dev/reference/pysupported.html#built-in-functions http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#other-methods