我在Python / Numba中编写了一个mergesort:
import numba as nb
import numpy as np
@nb.jit( nopython=True )
def merge( x ):
    """Return a sorted copy of ``x`` using an iterative, bottom-up merge sort.

    ``x`` is expected to be a one-dimensional NumPy array; it is copied, so the
    caller's array is left untouched. Two same-sized buffers are used in
    ping-pong fashion: every pass merges adjacent runs from one buffer into
    the other, then the two buffers swap roles.
    """
    n = x.shape[0]
    src = x.copy()
    dst = np.empty_like( src )
    run = 1
    while run < n:
        # Merge each pair of neighbouring runs [lo, mid) and [mid, hi). The
        # final pair may be short (or have an empty right run), so both
        # boundaries are clamped to n.
        for lo in range( 0, n, 2 * run ):
            mid = min( lo + run, n )
            hi = min( lo + 2 * run, n )
            _merge( src, dst, lo, mid, hi )
        # The freshly written buffer becomes the source of the next pass and
        # the old source is recycled as the next destination.
        src, dst = dst, src
        run *= 2
    return src
@nb.jit( nopython=True )
def _merge( src_arr, tgt_arr, istart, imid, iend ):
    """Merge two adjacent runs of ``src_arr`` into ``tgt_arr``.

    The runs are ``src_arr[istart:imid]`` and ``src_arr[imid:iend]`` (each is
    assumed to be sorted already). The merged result is written to
    ``tgt_arr[istart:iend]``; ``src_arr`` is only read.

    Ties are resolved in favour of the left run, which keeps the merge stable.
    """
    i0 = istart
    i1 = imid
    for ipos in range( istart, iend ):
        # Take from the left run while it still has elements and either the
        # right run is exhausted or the left head is not greater than the
        # right head. Using <= (not <) is what makes equal elements keep
        # their original relative order.
        if ( i0<imid ) and ( ( i1==iend ) or ( src_arr[ i0 ] <= src_arr[ i1 ] ) ):
            tgt_arr[ ipos ] = src_arr[ i0 ]
            i0+=1
        else:
            tgt_arr[ ipos ] = src_arr[ i1 ]
            i1+=1
我为它写了一个测试:
def test_merge_multi(self):
    """Compare ``sas.merge`` with ``np.sort(kind='merge')`` for sizes 21..99.

    For every size a random integer array is sorted by both implementations,
    the ratio of the two measured intervals is printed, and the results are
    asserted to be equal.
    """
    n0 = 21
    n1 = 100
    for n in range( n0, n1 ):
        # randint excludes its upper bound, so n + 1 keeps the inclusive
        # [0, n] range that the deprecated random_integers( 0, n ) produced.
        x = np.random.randint( 0, n + 1, size=n )
        with Timer( 'nb' ) as t0:
            r = sas.merge( x )
        with Timer( 'np' ) as t1:
            e = np.sort( x, kind='merge' )
        print( 'nb/np performance %s'%(t0.interval/t1.interval ))
        np.testing.assert_equal( e, r )
我使用了这个Timer类:
import time
class Timer:
    """Context manager that measures the time spent inside a ``with`` block.

    After the block exits, ``start``, ``end`` and ``interval`` (all in
    seconds) are available as attributes. When a truthy ``title`` is given, a
    message is printed on entry and the elapsed time is printed on exit;
    without a title the timer is silent.
    """

    def __init__(self,title=None):
        self.title=title

    def __enter__(self):
        if self.title:
            print( 'Beginning {0}'.format( self.title ) )
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # high-resolution clock intended for measuring short durations.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.interval = self.end - self.start
        if self.title:
            print( '{1} took {0:0.4f} seconds'.format( self.interval, self.title ) )
测试结果如下:
nb/np performance 9307.846153856719
nb/np performance 1.1428571428616743
nb/np performance 0.7142857142925115
nb/np performance 0.8333333333302494
nb/np performance 0.9999999999814962
nb/np performance 0.9999999999777955
nb/np performance 0.8333333333456692
nb/np performance 0.8333333333302494
nb/np performance 1.0
nb/np performance 0.8333333333456692
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 0.8333333333456692
nb/np performance 0.9999999999814962
nb/np performance 1.0
nb/np performance 0.9999999999814962
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0000000000185036
nb/np performance 1.2000000000044408
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.0000000000185036
nb/np performance 1.2000000000088817
nb/np performance 1.0
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 0.9999999999814962
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.0
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666728345
nb/np performance 1.1666666666512469
nb/np performance 1.1666666666512469
nb/np performance 1.0
nb/np performance 1.1666666666728345
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333024937
nb/np performance 1.3333333333456692
nb/np performance 1.1428571428435483
nb/np performance 1.3333333333209976
nb/np performance 1.1666666666728345
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333209976
nb/np performance 1.000000000012336
nb/np performance 1.1428571428616743
nb/np performance 1.3333333333456692
nb/np performance 1.3333333333209976
nb/np performance 1.1428571428616743
nb/np performance 1.1428571428616743
nb/np performance 1.3333333333456692
nb/np performance 1.499999999990748
nb/np performance 1.2857142857074884
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857029569
nb/np performance 1.1428571428616743
nb/np performance 1.1428571428435483
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857233488
nb/np performance 1.2857142857029569
nb/np performance 1.1249999999895917
nb/np performance 1.2857142857029569
nb/np performance 1.2857142857233488
nb/np performance 1.4285714285623656
nb/np performance 1.249999999993061
nb/np performance 1.1250000000034694
nb/np performance 1.2857142857029569
绘制结果(来自不同的运行):
较长时间运行的结果:
请注意,对于n <= 20,numpy在调用mergesort时使用插入排序:https://github.com/numpy/numpy/blob/master/numpy/core/src/npysort/mergesort.c.src
所以你可以看到,对于n的小值,mergesort的numba版本胜过numpy版本。
然而,随着n增大,numpy版本的表现始终以一定的倍数优于numba版本。
这是为什么?我该如何优化numba版本,使其对所有的n都胜过numpy版本?
答案 0 :(得分:2)
如果你的人生目标是击败numpy的实现,不妨尝试更贴近地重现它的做法。NumPy的实现与您实现的算法相比,有两个主要的算法差异:
NumPy通过实际的递归实现自顶向下的方法。您使用的是自下而上的方法,这免去了递归堆栈,但通常会产生不平衡的合并,从而降低效率。
虽然您的乒乓缓冲区方法很简洁,但您移动的数据量超出了严格必要的程度。像NumPy那样进行就地排序,至少可以把您的实现所需的总内存降到原来的75%,这也可能有助于缓存性能。
撇开Numba的魔法不谈,下面的代码与NumPy归并排序的内部运作非常接近:
def _mergesort(x, lo, hi, buffer):
if hi - lo <= 1:
return
# Python ints don't overflow, so we could do mid = (hi + lo) // 2
mid = lo + (hi - lo) // 2
_mergesort(x, lo, mid, buffer)
_mergesort(x, mid, hi, buffer)
buffer[:mid-lo] = x[lo:mid]
read_left = 0
read_right = mid
write = lo
while read_left < mid - lo and read_right < hi:
if x[read_right] < buffer[read_left]:
x[write] = x[read_right]
read_right += 1
else:
x[write] = buffer[read_left]
read_left += 1
write += 1
# bulk copy of left over entries from left subarray
x[write:read_right] = buffer[read_left:mid-lo]
# Left over entries in the right subarray are already in-place
def mergesort(x):
    """Return a sorted, flattened copy of ``x``; the input itself is not modified."""
    # Work on a private one-dimensional copy of the input
    flat = np.array(x, copy=True).ravel()
    size = flat.size
    # Each merge only stashes the left half of its range, so scratch space of
    # half the array length is enough
    scratch = np.empty(shape=(size // 2,), dtype=flat.dtype)
    _mergesort(flat, 0, size, scratch)
    return flat