如何将此树递归更改为尾递归?

时间:2013-10-31 17:12:59

标签: python recursion

我正在编写一个函数ChrNumber,它将阿拉伯数字字符串转换为中文财务数字字符串。我编写了一个树递归表单。但是当我试图得到一个尾递归形式时,我很难处理bit等于6,7或8或10和更大的情况。

你可以在我的问题的最后看到它是如何运作的。

这是树递归解决方案。它有效:

# -*- coding:utf-8 -*-

unitArab=(2,3,4,5,9)
#unitStr=u'十百千万亿' #this is an alternative
unitStr=u'拾佰仟万亿'
unitDic=dict(zip(unitArab,(list(unitStr))))
numArab=list(u'0123456789')
#numStr=u'零一二三四五六七八九' #this is an alternative
numStr=u'零壹贰叁肆伍陆柒捌玖'
numDic=dict(zip(numArab,list(numStr)))
def ChnNumber(s):
    def wrapper(v):
        'this is to adapt the string to a abbreviation'
        if u'零零' in v:
            return wrapper(v.replace(u'零零',u'零'))
        return v[:-1] if v[-1]==u'零' else v
    def recur(s,bit):
        'receives the number sting and its length'
        if bit==1:
            return numDic[s]
        if s[0]==u'0':
            return wrapper(u'%s%s' % (u'零',recur(s[1:],bit-1)))
        if bit<6 or bit==9:
            return wrapper(u'%s%s%s' % (numDic[s[0]],unitDic[bit],recur(s[1:],bit-1)))
        'below is the hard part to be converted to tail-recurion'
        if bit<9:
            return u'%s%s%s' % (recur(s[:-4],bit-4),u"万",recur(s[-4:],4))
        if bit>9:
            return u'%s%s%s' % (recur(s[:-8],bit-8),u"亿",recur(s[-8:],8))
    return recur(s,len(s))

我的尝试版本仅在recur函数中,我使用闭包res并将bit移到recur内,因此参数较少。:

res=[]
def recur(s):
    bit=len(s)
    print s,bit,res
    if bit==0:
        return ''.join(res)
    if bit==1:
        res.append(numDic[s])
        return recur(s[1:])
    if s[0]==u'0':
        res.append(u'零')
        return recur(s[1:])
    if bit<6 or bit==9:
        res.append(u'%s%s' %(numDic[s[0]],unitDic[bit]))
        return recur(s[1:])
    if bit<9:
        #...can't work it out
    if bit>9:
        #...can't work it out

测试代码是:

for i in range(17):
    v1='9'+'0'*(i+1)
    v2='9'+'0'*i+'9'
    v3='1'*(i+2)
    print '%s->%s\n%s->%s\n%s->%s'% (v1,ChnNumber(v1),v2,ChnNumber(v2),v3,ChnNumber(v3))

应输出:

>>> 
90->玖拾
99->玖拾玖
11->壹拾壹
900->玖佰
909->玖佰零玖
111->壹佰壹拾壹
9000->玖仟
9009->玖仟零玖
1111->壹仟壹佰壹拾壹
90000->玖万
90009->玖万零玖
11111->壹万壹仟壹佰壹拾壹
900000->玖拾万
900009->玖拾万零玖
111111->壹拾壹万壹仟壹佰壹拾壹
9000000->玖佰万
9000009->玖佰万零玖
1111111->壹佰壹拾壹万壹仟壹佰壹拾壹
90000000->玖仟万
90000009->玖仟万零玖
11111111->壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000->玖亿
900000009->玖亿零玖
111111111->壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000->玖拾亿
9000000009->玖拾亿零玖
1111111111->壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000->玖佰亿
90000000009->玖佰亿零玖
11111111111->壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000->玖仟亿
900000000009->玖仟亿零玖
111111111111->壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000000->玖万亿
9000000000009->玖万亿零玖
1111111111111->壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000000->玖拾万亿
90000000000009->玖拾万亿零玖
11111111111111->壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000000->玖佰万亿
900000000000009->玖佰万亿零玖
111111111111111->壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000000000->玖仟万亿
9000000000000009->玖仟万亿零玖
1111111111111111->壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000000000->玖亿亿
90000000000000009->玖亿亿零玖
11111111111111111->壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000000000->玖拾亿亿
900000000000000009->玖拾亿亿零玖
111111111111111111->壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹

2 个答案:

答案 0 :(得分:1)

Python不支持尾调用消除和尾调用优化。但是,有很多方法可以模仿这种方法(Trampolines是其他语言中使用最广泛的方法。)

尾调用递归函数应该类似于以下伪代码:

def tail_call(*args, acc):
  if condition(*args):
    return acc
  else:
    # Operations happen here, producing new_args and new_acc
    return tail_call(*new_args, new_acc)

对于你的例子,我不会形成对任何东西的封闭,因为你正在引入副作用和有状态的操纵。相反,任何需要修改的东西都应该独立于其他所有东西进行修改。这使得更容易推理。

复制您尝试更改的内容(使用string.copy作为最终输出)并将其作为参数传递给下一个递归调用。这就是acc变量发挥作用的地方。到目前为止,它正在“累积”所有更改。

经典蹦床可以来自this snippet。在那里,它们将函数包装在一个对象中,该对象最终会导致结果或返回另一个应该被调用的函数对象。我更喜欢这种方法,因为我觉得它更容易推理。

这不是唯一的方法。看看this code snippet。当它到达“解决”条件的点并且它抛出异常以逃避无限循环时,就会发生“魔法”。

最后,您可以阅读有关Trampolines hereherehere的信息。

答案 1 :(得分:0)

这些天我一直在研究这个问题。现在,我解决了!

注意,不仅仅是尾递归,它也是纯函数式编程!

关键是以不同的方式思考(树递归版本从左到右处理数字,而此版本是从右到左)

unitDic=dict(zip(range(8),u'拾佰仟万拾佰仟亿'))
numDic=dict(zip('0123456789',u'零壹贰叁肆伍陆柒捌玖'))
wapDic=[(u'零拾',u'零'),(u'零佰',u'零'),(u'零仟',u'零'),
        (u'零万',u'万'),(u'零亿',u'亿'),(u'亿万',u'亿'),
        (u'零零',u'零'),]

#pure FP
def ChnNumber(s):
    def wrapper(s,wd=wapDic):
        def rep(s,k,v):
            if k in s:
                return rep(s.replace(k,v),k,v)
            return s    
        if not wd:
            return s
        return wrapper(rep(s,*wd[0]),wd[1:])
    def recur(s,acc='',ind=0):        
        if s=='':
            return acc
        return recur(s[:-1],numDic[s[-1]]+unitDic[ind%8]+acc,ind+1)
    def end(s):
        if s[-1]!='0':
            return numDic[s[-1]]
        return ''
    def result(start,end):
        if end=='' and start[-1]==u'零':
            return start[:-1]
        return start+end
    return result(wrapper(recur(s[:-1])),end(s))

for i in range(18):    
    v1='9'+'0'*(i+1)
    v2='9'+'0'*i+'9'
    v3='1'*(i+2)
    print ('%s->%s\n%s->%s\n%s->%s'% (v1,ChnNumber(v1),v2,ChnNumber(v2),v3,ChnNumber(v3)))

如果任何人说面对一个巨大的数字(类似十亿数字的数字)它将不起作用,是的,我承认,但是这个版本可以解决它(虽然它不会是纯粹的FP而是纯粹的FP不需要这个版本..):

class TailCaller(object) :
    def __init__(self, f) :
        self.f = f
    def __call__(self, *args, **kwargs) :
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall :
            ret = ret.handle()
        return ret

class TailCall(object) :
    def __init__(self, call, *args, **kwargs) :
        self.call = call
        self.args = args
        self.kwargs = kwargs
    def handle(self) :
        if type(self.call) is TailCaller :
            return self.call.f(*self.args, **self.kwargs)
        else :
            return self.f(*self.args, **self.kwargs)

def ChnNumber(s):
    def wrapper(s,wd=wapDic):
        @TailCaller
        def rep(s,k,v):
            if k in s:
                return TailCall(rep,s.replace(k,v),k,v)
            return s    
        if not wd:
            return s
        return wrapper(rep(s,*wd[0]),wd[1:])
    @TailCaller
    def recur(s,acc='',ind=0):        
        if s=='':
            return acc
        return TailCall(recur,s[:-1],numDic[s[-1]]+unitDic[ind%8]+acc,ind+1)
    def end(s):
        if s[-1]!='0':
            return numDic[s[-1]]
        return ''
    def result(start,end):
        if end=='' and start[-1]==u'零':
            return start[:-1]
        return start+end
    return result(wrapper(recur(s[:-1])),end(s))