I'm trying to implement a quicksort algorithm in Python using numba.
It seems to be much slower than NumPy's sort function.
How can I improve it? My code is here:
import numpy as np
import numba as nb

# note: nb.autojit only exists in older numba releases; newer ones use nb.jit
@nb.autojit
def quick_sort(list_):
    """
    Iterative version of quick sort
    """
    #temp_stack = []
    #temp_stack.append((left,right))
    max_depth = 1000
    left = 0
    right = list_.shape[0] - 1
    i_stack_pos = 0
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32)
    a_temp_stack[i_stack_pos, 0] = left
    a_temp_stack[i_stack_pos, 1] = right
    i_stack_pos += 1
    # Main loop to pop and push items until stack is empty
    while i_stack_pos > 0:
        i_stack_pos -= 1
        right = a_temp_stack[i_stack_pos, 1]
        left = a_temp_stack[i_stack_pos, 0]
        piv = partition(list_, left, right)
        # If items in the left of the pivot push them to the stack
        if piv - 1 > left:
            #temp_stack.append((left,piv-1))
            a_temp_stack[i_stack_pos, 0] = left
            a_temp_stack[i_stack_pos, 1] = piv - 1
            i_stack_pos += 1
        # If items in the right of the pivot push them to the stack
        if piv + 1 < right:
            a_temp_stack[i_stack_pos, 0] = piv + 1
            a_temp_stack[i_stack_pos, 1] = right
            i_stack_pos += 1

@nb.autojit(nopython=True)
def partition(list_, left, right):
    """
    Partition method
    """
    # Pivot: first element in the array
    piv = list_[left]
    i = left + 1
    j = right
    while 1:
        while i <= j and list_[i] <= piv:
            i += 1
        while j >= i and list_[j] >= piv:
            j -= 1
        if j <= i:
            break
        # Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    # Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j
My test code is here:
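# randint's upper bound is exclusive, so this matches the deprecated
# np.random.random_integers(0, 1000, 1000000)
x = np.random.randint(0, 1001, 1000000)
y = x.copy()
quick_sort(y)  # first call also triggers the JIT compilation
z = np.sort(x)
np.testing.assert_array_equal(z, y)

y = x.copy()
# Timer: a timing context manager (not shown in the question)
with Timer('nb'):
    numba_fns.quick_sort(y)
with Timer('np'):
    x = np.sort(x)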
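The question does not show its Timer helper. A minimal stand-in, assuming it is simply a context manager that prints the elapsed wall-clock time of the block under a label (the class below is hypothetical, not the asker's code):

```python
import time


class Timer:
    """Hypothetical stand-in for the question's Timer (not shown there).

    Measures the wall-clock time of a ``with`` block and prints it with a label.
    """

    def __init__(self, label):
        self.label = label
        self.elapsed = None

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.label}: {self.elapsed:.4f} s")
        return False  # never suppress exceptions raised inside the block
```

When timing a numba function this way, time it only after a warm-up call, as the test above does, so that the one-off JIT compilation is not counted.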
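Update:
I rewrote the function to force the loop part of the code to run in nopython mode. The while loop does not seem to make nopython mode fail. However, I didn't get any performance gain: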
@nb.autojit
def quick_sort2(list_):
    """
    Iterative version of quick sort
    """
    max_depth = 1000
    left = 0
    right = list_.shape[0] - 1
    i_stack_pos = 0
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32)
    a_temp_stack[i_stack_pos, 0] = left
    a_temp_stack[i_stack_pos, 1] = right
    i_stack_pos += 1
    # Main loop to pop and push items until stack is empty
    return _quick_sort2(list_, a_temp_stack, left, right)

@nb.autojit(nopython=True)
def _quick_sort2(list_, a_temp_stack, left, right):
    i_stack_pos = 1
    while i_stack_pos > 0:
        i_stack_pos -= 1
        right = a_temp_stack[i_stack_pos, 1]
        left = a_temp_stack[i_stack_pos, 0]
        piv = partition(list_, left, right)
        # If items in the left of the pivot push them to the stack
        if piv - 1 > left:
            a_temp_stack[i_stack_pos, 0] = left
            a_temp_stack[i_stack_pos, 1] = piv - 1
            i_stack_pos += 1
        if piv + 1 < right:
            a_temp_stack[i_stack_pos, 0] = piv + 1
            a_temp_stack[i_stack_pos, 1] = right
            i_stack_pos += 1
Answer 0 (score: 3)
A small suggestion that might help (although, as you were rightly told in the comments, you will find it hard to beat a pure C implementation):
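You want to make sure that most of the work is done in "nopython" mode (@jit(nopython=True)). Add it before your functions and see where it breaks. Also call inspect_types() on your functions and check whether numba infers their types correctly.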
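One thing in your code that may be forcing it into object mode (as opposed to nopython mode) is the allocation of a numpy array. Although numba can compile loops separately in nopython mode (loop lifting), I don't know whether it can do this for while loops. Calling inspect_types will tell you.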
The workaround I usually use for creating numpy arrays, while making sure the rest runs in nopython mode, is to write a wrapper function:
import numpy as np
import numba as nb

@nb.jit(nopython=True)  # make sure it can be done in nopython mode
def _quick_sort_impl(list_, a_temp_stack):
    ...  # most of your code goes here

# note: in recent numba releases (0.59+) @jit defaults to nopython mode,
# so there this wrapper would be left undecorated instead
@nb.jit
def quick_sort(list_):
    # this code won't compile in nopython mode, but it's
    # short and isolated
    max_depth = 1000
    a_temp_stack = np.ndarray((max_depth, 2), dtype=np.int32)
    _quick_sort_impl(list_, a_temp_stack)
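One way to fill in the wrapper pattern, sketched under a few assumptions: the compiled core reuses the question's own partition scheme and stack loop; it uses numba's njit (shorthand for jit(nopython=True)) and falls back to plain, uncompiled Python if numba is not installed; the stack is allocated with np.empty as int64; and, as a tweak that is not in the question, the larger sub-range is pushed first. The names _partition and _quick_sort_impl are illustrative.

```python
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Assumption: fallback so the sketch still runs (uncompiled, slowly) without numba
    def njit(func):
        return func


@njit
def _partition(a, left, right):
    # Same scheme as the question's partition(): pivot on the first element
    piv = a[left]
    i = left + 1
    j = right
    while True:
        while i <= j and a[i] <= piv:
            i += 1
        while j >= i and a[j] >= piv:
            j -= 1
        if j <= i:
            break
        a[i], a[j] = a[j], a[i]
    # Move the pivot into its final position
    a[left], a[j] = a[j], a[left]
    return j


@njit
def _quick_sort_impl(a, stack):
    # Explicit stack of (left, right) index pairs replaces recursion
    stack[0, 0] = 0
    stack[0, 1] = a.shape[0] - 1
    top = 1
    while top > 0:
        top -= 1
        left = stack[top, 0]
        right = stack[top, 1]
        piv = _partition(a, left, right)
        # Tweak (not in the question): push the larger side first so the smaller
        # side is popped next, which keeps the stack roughly log2(n) entries deep.
        if piv - left > right - piv:
            if piv - 1 > left:
                stack[top, 0] = left
                stack[top, 1] = piv - 1
                top += 1
            if piv + 1 < right:
                stack[top, 0] = piv + 1
                stack[top, 1] = right
                top += 1
        else:
            if piv + 1 < right:
                stack[top, 0] = piv + 1
                stack[top, 1] = right
                top += 1
            if piv - 1 > left:
                stack[top, 0] = left
                stack[top, 1] = piv - 1
                top += 1


def quick_sort(a, max_depth=1000):
    # Plain-Python wrapper: the allocation lives here, outside the compiled code
    if a.shape[0] < 2:
        return
    stack = np.empty((max_depth, 2), dtype=np.int64)
    _quick_sort_impl(a, stack)
```

The push-order tweak matters because numba does not bounds-check array indexing by default, so overflowing a fixed-size stack would silently write past its end instead of raising an error.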
Answer 1 (score: 3)
In general, if you don't force nopython mode, chances are you won't get any performance gain. Quoting the docs about nopython mode:
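"[nopython] mode produces the highest-performance code, but requires that the native types of all values in the function can be inferred, and that no new objects are allocated."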
So your np.ndarray call triggers object mode, which slows the code down. Try allocating the work array outside the function, for example:
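import numpy as np
import numba

def quick_sort(list_):
    max_depth = 1000
    # np.empty gives an uninitialized (max_depth, 2) array; np.array((max_depth, 2))
    # would instead build the 1-D array [1000, 2]
    temp_stack_ = np.empty((max_depth, 2), dtype=np.int32)
    _quick_sort(list_, temp_stack_)

...

@numba.jit(nopython=True)
def _quick_sort(list_, temp_stack_):
    ...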
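A hedged caveat on the allocation point: in recent numba releases, np.empty(shape, dtype) is on the list of NumPy functions supported in nopython mode, so the allocation itself does not have to force object mode (I am not sure the np.ndarray(...) constructor used in the question is supported). The helper below, new_work_stack, is hypothetical; the try/except fallback lets it run uncompiled if numba is missing.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fallback so the sketch still runs (uncompiled) without numba
    def njit(func):
        return func


@njit
def new_work_stack(length, max_depth):
    # np.empty with a shape tuple and a dtype compiles in nopython mode
    stack = np.empty((max_depth, 2), dtype=np.int32)
    # Seed the stack with the full index range, as quick_sort() does
    stack[0, 0] = 0
    stack[0, 1] = length - 1
    return stack
```

If this compiles under njit on your numba version, the whole sort can live in one compiled function and the wrapper becomes optional.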
Answer 2 (score: 0)
For what it's worth, numba has implemented the generic sorted function and the numpy-array .sort() method since (I think) version 0.22. Hooray!
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html#built-in-functions http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#other-methods
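A minimal sketch of what this answer describes, assuming a numba release that supports ndarray.sort() and np.sort() in nopython mode; the function names are hypothetical, and the fallback decorator lets the code run uncompiled if numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fallback so the sketch still runs (uncompiled) without numba
    def njit(func):
        return func


@njit
def sort_in_place(a):
    # ndarray.sort() sorts the array in place inside compiled code
    a.sort()


@njit
def sorted_copy(a):
    # np.sort() returns a new sorted array and leaves the input untouched
    return np.sort(a)
```

Either call can be used inside a larger nopython function instead of a hand-written quicksort.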