为什么我的基数排序python实现比快速排序慢?

时间:2011-11-26 15:53:36

标签: python quicksort radix-sort

我使用 SciPy 中的数组重写了来自 Wikipedia 的Python的原始基数排序算法,以获得性能并减少代码长度,这是我设法完成的。然后我从 Literate Programming 中获取了经典(内存中,基于数据透视)的快速排序算法,并比较了它们的性能。

我期望基数排序会快速超过一定的阈值,但事实并非如此。此外,我发现Erik Gorset's Blog's提出问题" 基数排序比整数数组的快速排序更快吗?"。答案是那个

  

..基准测试显示,对于大型阵列,MSB就地基数排序的速度始终比快速排序快3倍。

不幸的是,我无法重现结果;不同之处在于(a)Erik选择Java而不是Python,(b)他使用 MSB就地基数排序,而我只是在Python字典中填写 buckets

根据理论,与快速排序相比,基数排序应该更快(线性);但显然它在很大程度上取决于实施。那我的错误在哪里?

以下是比较两种算法的代码:

from sys   import argv
from time  import clock

from pylab import array, vectorize
from pylab import absolute, log10, randint
from pylab import semilogy, grid, legend, title, show

###############################################################################
# radix sort
###############################################################################

def splitmerge0 (ls, digit): ## python (pure!)

    seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
    buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}

    return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])

def splitmergeX (ls, digit): ## python & numpy

    seq = array (vectorize (lambda n: ((n // 10 ** digit) % 10, n)) (ls)).T
    buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}

    return array (reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), []))

def radixsort (ls, fn = splitmergeX):

    return reduce (fn, xrange (int (log10 (absolute (ls).max ()) + 1)), ls)

###############################################################################
# quick sort
###############################################################################

def partition (ls, start, end, pivot_index):

    lower = start
    upper = end - 1

    pivot = ls[pivot_index]
    ls[pivot_index] = ls[end]

    while True:

        while lower <= upper and ls[lower] <  pivot: lower += 1
        while lower <= upper and ls[upper] >= pivot: upper -= 1
        if lower > upper: break

        ls[lower], ls[upper] = ls[upper], ls[lower]

    ls[end] = ls[lower]
    ls[lower] = pivot

    return lower

def qsort_range (ls, start, end):

    if end - start + 1 < 32:
        insertion_sort(ls, start, end)
    else:
        pivot_index = partition (ls, start, end, randint (start, end))
        qsort_range (ls, start, pivot_index - 1)
        qsort_range (ls, pivot_index + 1, end)

    return ls

def insertion_sort (ls, start, end):

    for idx in xrange (start, end + 1):
        el = ls[idx]
        for jdx in reversed (xrange(0, idx)):
            if ls[jdx] <= el:
                ls[jdx + 1] = el
                break
            ls[jdx + 1] = ls[jdx]
        else:
            ls[0] = el

    return ls

def quicksort (ls):

    return qsort_range (ls, 0, len (ls) - 1)

###############################################################################
if __name__ == "__main__":
###############################################################################

    lower = int (argv [1]) ## requires: >= 2
    upper = int (argv [2]) ## requires: >= 2
    color = dict (enumerate (3*['r','g','b','c','m','k']))

    rslbl = "radix sort"
    qslbl = "quick sort"

    for value in xrange (lower, upper):

        #######################################################################

        ls = randint (1, value, size=value)

        t0 = clock ()
        rs = radixsort (ls)
        dt = clock () - t0

        print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
        semilogy (value, dt, '%s.' % color[int (log10 (value))], label=rslbl)

        #######################################################################

        ls = randint (1, value, size=value)

        t0 = clock ()
        rs = quicksort (ls)
        dt = clock () - t0

        print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
        semilogy (value, dt, '%sx' % color[int (log10 (value))], label=qslbl)

    grid ()
    legend ((rslbl,qslbl), numpoints=3, shadow=True, prop={'size':'small'})
    title ('radix & quick sort: #(integer) vs duration [s]')
    show ()

###############################################################################
###############################################################################

以下是比较大小范围为2到1250(水平轴)的整数数组的排序持续时间(以秒为单位)(对数垂直轴)的结果;下曲线属于快速排序:

快速排序在功率变化时是平滑的(例如在10,100或1000),但基数排序只是略微跳跃,但在定性上与快速排序相同,只是慢得多!

2 个答案:

答案 0 :(得分:3)

这里有几个问题。

首先,正如评论中所指出的,您的数据集太小而不能理解复杂性以克服代码中的开销。

接下来,所有那些不必要的函数调用和复制列表的实现效率非常低。以简单的程序方式编写代码几乎总是比功能解决方案更快(对于Python,其他语言在这里会有所不同)。你有一个quicksort的程序实现,所以如果你用相同的样式编写基数排序,即使对于小列表也可能更快。

最后,可能是当你尝试大型列表时,内存管理的开销开始占主导地位。这意味着在小型列表之间有一个有限的窗口,其中实现的效率是主导因素,而大型列表是内存管理的主要因素。

这里有一些使用你的快速排序的代码,但是程序上编写了一个简单的radixsort,但试图避免这么多的数据复制。您会看到,即使是短列表,它也会超过快速排序,但更有趣的是随着数据大小的增加,快速排序和基数排序之间的比例也会随着内存管理开始占主导地位而再次开始下降(像释放一样简单的事情) 1,000,000个项目的清单需要很长时间):

from random import randint
from math import log10
from time import clock
from itertools import chain

def splitmerge0 (ls, digit): ## python (pure!)

    seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
    buf = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}

    return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])

def splitmerge1 (ls, digit): ## python (readable!)
    buf = [[] for i in range(10)]
    divisor = 10 ** digit
    for n in ls:
        buf[(n//divisor)%10].append(n)
    return chain(*buf)

def radixsort (ls, fn = splitmerge1):
    return list(reduce (fn, xrange (int (log10 (max(abs(val) for val in ls)) + 1)), ls))

###############################################################################
# quick sort
###############################################################################

def partition (ls, start, end, pivot_index):

    lower = start
    upper = end - 1

    pivot = ls[pivot_index]
    ls[pivot_index] = ls[end]

    while True:

        while lower <= upper and ls[lower] <  pivot: lower += 1
        while lower <= upper and ls[upper] >= pivot: upper -= 1
        if lower > upper: break

        ls[lower], ls[upper] = ls[upper], ls[lower]

    ls[end] = ls[lower]
    ls[lower] = pivot

    return lower

def qsort_range (ls, start, end):

    if end - start + 1 < 32:
        insertion_sort(ls, start, end)
    else:
        pivot_index = partition (ls, start, end, randint (start, end))
        qsort_range (ls, start, pivot_index - 1)
        qsort_range (ls, pivot_index + 1, end)

    return ls

def insertion_sort (ls, start, end):

    for idx in xrange (start, end + 1):
        el = ls[idx]
        for jdx in reversed (xrange(0, idx)):
            if ls[jdx] <= el:
                ls[jdx + 1] = el
                break
            ls[jdx + 1] = ls[jdx]
        else:
            ls[0] = el

    return ls

def quicksort (ls):

    return qsort_range (ls, 0, len (ls) - 1)

if __name__=='__main__':
    for value in 1000, 10000, 100000, 1000000, 10000000:
        ls = [randint (1, value) for _ in range(value)]
        ls2 = list(ls)
        last = -1
        start = clock()
        ls = radixsort(ls)
        end = clock()
        for i in ls:
            assert last <= i
            last = i
        print("rs %d: %0.2fs" % (value, end-start))
        tdiff = end-start
        start = clock()
        ls2 = quicksort(ls2)
        end = clock()
        last = -1
        for i in ls2:
            assert last <= i
            last = i
        print("qs %d: %0.2fs %0.2f%%" % (value, end-start, ((end-start)/tdiff*100)))

运行时的输出是:

C:\temp>c:\python27\python radixsort.py
rs 1000: 0.00s
qs 1000: 0.00s 212.98%
rs 10000: 0.02s
qs 10000: 0.05s 291.28%
rs 100000: 0.19s
qs 100000: 0.58s 311.98%
rs 1000000: 2.47s
qs 1000000: 7.07s 286.33%
rs 10000000: 31.74s
qs 10000000: 86.04s 271.08%

修改: 只是为了澄清。这里的快速排序实现非常友好,它就地排序,所以无论列表有多大,它都只是在不复制数据的情况下改组数据。原始的radixsort有效地为每个数字复制两次列表:一次进入较小的列表,然后在连接列表时再次复制。使用itertools.chain避免了第二个副本,但仍然有很多内存分配/释放。 (同样'两次'是近似值,因为列表追加确实涉及额外复制,即使它是分摊O(1)所以我应该说'与两倍成比例'。)

答案 1 :(得分:0)

您的数据表示非常昂贵。为什么你的桶使用 hashmap ?为什么要使用base10表示来计算对数(=计算代价很高)?

避免使用lambda表达式,我不认为python可以很好地优化它们。

也许从为基准测试排序10字节字符串开始。并且:没有Hashmaps和类似的昂贵数据结构。