在python中的Burrows-Wheeler中的性能问题

时间:2014-01-23 01:58:33

标签: python performance algorithm

我试图在python中实现Burrows-Wheeler转换。 (这是在线课程的任务之一,但我希望我做了一些工作才有资格寻求帮助。)

该算法的工作原理如下。取一个以特殊字符结尾的字符串(在我的情况下是$)并从该字符串创建所有循环字符串。按字母顺序对所有这些字符串进行排序,使特殊字符始终小于任何其他字符。在此之后获取每个字符串的最后一个元素。

这给了我一个oneliner:

''.join([i[-1] for i in sorted([text[i:] + text[0:i] for i in xrange(len(text))])]

对于相当大的字符串来说哪个是正确且合理的快(这足以解决问题):

 60 000 chars - 16 secs
 40 000 chars - 07 secs
 25 000 chars - 02 secs

但是当我尝试处理一个包含数百万个字符的巨大字符串时,我失败了(处理时间太长)。

我认为问题在于在内存中存储太多字符串。

有没有办法克服这个问题?

P.S。只是想指出这也可能看起来像是一个家庭作业问题,我的解决方案已经通过了分级机,我只是想找到一种方法来加快速度。此外,我并没有破坏其他人的乐趣,因为如果他们想找到解决方案,维基文章就有一个类似于我的解决方案。我还检查了this questio n这听起来很相似,但回答了一个更难的问题,如何解码用这种算法编码的字符串。

3 个答案:

答案 0 :(得分:10)

使用长字符串制作所有字符串切片需要很长时间。它至少 O(N ^ 2)(因为你创建了N个N长度的字符串,并且每个字符串都必须从原始数据库中复制到内存中),这会破坏整体性能并且使排序无关紧要。更不用说内存要求了!

而不是实际切片字符串,下一个想法是按照结果字符串 比较的顺序排序用于创建循环字符串的i值 - 实际上没有创造它。事实证明这有点棘手。 (删除/编辑了一些错误的内容;请参阅@TimPeters的回答。

我在这里采用的方法是绕过标准库 - 这使得“按需”(虽然不是不可能)比较这些字符串“按需” - 并进行自己的排序。这里算法的自然选择是基数排序,因为我们无论如何都需要一次考虑一个字符。

让我们先设置好。我正在编写版本3.2的代码,所以季节尝试。 (特别是在3.3及更高版本中,我们可以利用yield from。)我使用以下导入:

from random import choice
from timeit import timeit
from functools import partial

我写了一个这样的通用基数排序函数:

def radix_sort(values, key, step=0):
    if len(values) < 2:
        for value in values:
            yield value
        return

    bins = {}
    for value in values:
        bins.setdefault(key(value, step), []).append(value)

    for k in sorted(bins.keys()):
        for r in radix_sort(bins[k], key, step + 1):
            yield r

当然,我们不需要是通用的(我们的'bins'只能用单个字符标记,并且可能你真的意味着将算法应用于< strong> bytes ;)),但它不会受到伤害。还有可重复使用的东西,对吧?无论如何,这个想法很简单:我们处理一个基本情况,然后我们根据key函数的结果将每个元素放入一个“bin”中,然后我们按照排序的bin顺序从bin中提取值,递归地对每个元素进行排序bin的内容。

界面要求key(value, n)向我们提供n的{​​{1}}“基数”。因此,对于简单的情况,比如直接比较字符串,这可能很简单,如value。但是,这里的想法是根据该点的字符串中的数据(周期性地考虑)将索引与字符串进行比较。所以让我们定义一个关键:

lambda v, n: return v[n]

现在获得正确结果的诀窍是要记住我们在概念上加入了我们实际上没有创建的字符串的最后一个字符。如果我们考虑使用索引def bw_key(text, value, step): return text[(value + step) % len(text)] 创建的虚拟字符串,它的最后一个字符位于索引n,因为我们如何环绕 - 并且片刻的想法会向您确认这仍然有效{{1 }); [但是,当我们向前包装时,我们仍然需要保持字符串索引入界 - 因此在键函数中进行模运算。]

这是一个通用的关键函数,需要在转换n - 1进行比较时引用的n == 0中传递。这就是text进来的地方 - 你也可以随便找一下value,但这可以说是更清洁了,而且我发现它通常也更快。

无论如何,现在我们可以使用密钥轻松编写实际的转换:

functools.partial

很好漂亮。让我们看看它是怎么做的,不是吗?我们需要一个标准来比较它:

lambda

时间例程:

def burroughs_wheeler_custom(text):
    return ''.join(text[i - 1] for i in radix_sort(range(len(text)), partial(bw_key, text)))
    # Notice I've dropped the square brackets; this means I'm passing a generator
    # expression to `join` instead of a list comprehension. In general, this is
    # a little slower, but uses less memory. And the underlying code uses lazy
    # evaluation heavily, so :)

注意我已做过的数学决定def burroughs_wheeler_standard(text): return ''.join([i[-1] for i in sorted([text[i:] + text[:i] for i in range(len(text))])]) 的数量,与def test(n): data = ''.join(choice('abcdefghijklmnopqrstuvwxyz') for i in range(n)) + '$' custom = partial(burroughs_wheeler_custom, data) standard = partial(burroughs_wheeler_standard, data) assert custom() == standard() trials = 1000000 // n custom_time = timeit(custom, number=trials) standard_time = timeit(standard, number=trials) print("custom: {} standard: {}".format(custom_time, standard_time)) 字符串的长度成反比。这应该将用于测试的总时间保持在合理的范围内 - 对吧? ;)(当然,错误,因为我们确定trials算法至少是O(N ^ 2)。)

让我们看看它是如何做的(* drumroll *):

test
哇,哇,这有点令人恐惧。无论如何,正如您所看到的,新方法在短字符串上增加了大量开销,但却使实际排序成为瓶颈而不是字符串切片。 :)

答案 1 :(得分:6)

只需添加一点@KarlKnechtel的即时回复。

首先,加速循环置换提取的“标准方法”就是将两个副本粘贴在一起并直接索引到其中。后:

N = len(text)
text2 = text * 2

然后从索引i开始的循环排列仅为text2[i: i+N],而该排列中的字符j仅为text2[i+j]。无需将两个切片或模数(%)操作粘贴在一起。

其次,内置sort()可以用于此,但是:

  1. 它很时髦; - )
  2. 对于字符串很少的字符串(与字符串的长度相比),Karl的基数排序几乎肯定会更快。
  3. 作为概念验证,这里是Karl代码部分的替代品(虽然这很适合Python 2):

    def burroughs_wheeler_custom(text):
        N = len(text)
        text2 = text * 2
        class K:
            def __init__(self, i):
                self.i = i
            def __lt__(a, b):
                i, j = a.i, b.i
                for k in xrange(N): # use `range()` in Python 3
                    if text2[i+k] < text2[j+k]:
                        return True
                    elif text2[i+k] > text2[j+k]:
                        return False
                return False # they're equal
    
        inorder = sorted(range(N), key=K)
        return "".join(text2[i+N-1] for i in inorder)
    

    请注意,内置sort()的实现为其输入中的每个元素计算一次密钥, 保存这些结果的持续时间为那种。在这种情况下,结果是只记得起始索引的惰性小K个实例,并且其__lt__方法一次比较一个字符对,直到“小于!”。或“大于!”已经解决了。

答案 2 :(得分:1)

我同意之前的回答,python中的字符串/列表切片在执行大量算法计算时成为瓶颈。这个想法是不切片

[编辑:不切片,但列表索引。如果使用array.array而不是list,则执行时间减少到一半。索引数组很简单,索引列表是一个更复杂的过程)]

这里有一个更实用的问题解决方案。

这个想法是有一个生成器,它将充当切片器(rslice)。这与itertools.islice的想法类似,但是当它到达结尾时它会转到字符串的开头。并且在到达创建它时指定的起始位置之前它将停止。有了这个技巧,你不会在内存中复制任何子串,所以最后你只有指针 移动你的字符串而不是在任何地方创建副本。

所以我们创建一个包含[rslices,切片的lastchar]的列表 我们使用rslice作为关键字进行排序(如cf sort函数中所示)。

当它被排序时,你只需要为列表中的每个元素收集第二个元素(先前存储的切片的最后一个元素)。

from itertools import izip
def cf(i1,i2):
    for i,j in izip(i1[0](),i2[0]()): # We grab the the first element (is a lambda) and execute it to get the generator
        if i<j: return -1
        elif i>j: return 1
    return 0

def rslice(cad,pos): # Slice that rotates through the string (it's a generator)
    pini=pos
    lc=len(cad)
    while pos<lc:
        yield cad[pos]
        pos+=1
    pos=0
    while pos<pini-1:
        yield cad[pos]
        pos+=1

def lambdagen(start,cad): # Closure to hold a generator
    return lambda: rslice(cad,start)

def bwt(txt):
    lt=len(txt)
    arry=list(txt)+[None]

    l=[(lambdagen(0,arry),None)]+[(lambdagen(i,arry),arry[i-1]) for i in range(1,lt+1)]
    # What we keep in the list is the generator for the rotating-slice, plus the 
    # last character of the slice, so we save the time of going through the whole 
    # string to get the last character

    l.sort(cmp=cf)   # We sort using our cf function
    return [i[1] for i in l]

print bwt('Text I want to apply BTW to :D')

# ['D', 'o', 'y', 't', 'o', 'W', 't', 'I', ' ', ' ', ':', ' ', 'B', None, 'T', 'w', ' ', 
# 'T', 'p', 'a', 't', 't', 'p', 'a', 'x', 'n', ' ', ' ', ' ', 'e', 'l']

编辑:使用数组(执行时间减少2):

def bwt(txt):
    lt=len(txt)
    arry=array.array('h',[ord(i) for i in txt])
    arry.append(-1)

    l=[(lambdagen(0,arry),None)]+[(lambdagen(i,arry),arry[i-1]) for i in range(1,lt+1)]

    l.sort(cmp=cf)
    return [i[1] for i in l]