在函数

时间:2016-12-28 13:56:29

标签: python optimization

我有一个很长的数学公式(只是为了让你进入上下文:它有293095个字符),实际上它将是python函数的主体。此函数具有15输入参数,如:

def math_func(t,X,P,n1,n2,R,r):
    x,y,z = X
    a,b,c = P
    u1,v1,w1 = n1
    u2,v2,w2 = n2
    return <long math formula>

该公式使用简单的数学运算+ - * ** /和一个函数调用arctan。这是它的摘录:

r*((-16*(r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
 20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
 r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
 18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
 3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
 3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
 24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
 24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
 12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
 12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2 + ...  

现在重点是,在实践中,此功能的批量评估将针对P,n1,n2,Rr的固定值进行,这会将 free 变量的集合减少到仅四,“理论上”参数较少的公式应该更快。

所以问题是:如何在Python中实现此优化?

我知道我可以将所有内容都放在一个字符串中并执行某种replacecompileeval,如

formula = formula.replace('r','1').replace('R','2')....
code = compile(formula,'formula-name','eval')
math_func = lambda t,x,y,z: eval(code)

如果某些操作(如电源)被其值替换会很好,例如18*r**6*t*u1**2*u2**2*v1**2应该成为18*t的{​​{1}}。我认为r=u1=u2=v1=1应该这样做,但无论如何我不确定。 compile是否实际执行此优化?

我的解决方案加快了计算速度,但是如果我可以更多地进行计算,那就更好了。注意:最好在标准Python中(我可以稍后尝试Cython)。

总的来说,我很有兴趣以 pythonic 的方式来实现我的目标,可能还有一些额外的库:什么是相当不错的方式?我的解决方案是良好的方法吗?

编辑:(提供更多背景信息)

巨大的表达式是在圆弧上积分的符号线的输出。弧在空间中由半径compile,两个正常法向量(如2D版本中的x和y轴)rn1=(u1,v1,w1)和中心{{1}给出}。其余的是我正在为我正在集成的函数执行集成n2=(u2,v2,w2)和参数P=(a,b,c)的点。

X=(x,y,z)R只需要很长时间来计算,实际输出来自Sympy

如果你对这里的公式感到好奇,那就是(伪伪代码):

Maple

3 个答案:

答案 0 :(得分:2)

你可以使用Sympy:

>>> from sympy import symbols
>>> x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r = symbols("x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r")
>>> r=u1=u2=v1=1
>>> a = 18*r**6*t*u1**2*u2**2*v1**2
>>> a
18*t

然后你可以创建一个像这样的Python函数:

>>> from sympy import lambdify
>>> f = lambdify(t, a)
>>> f(1)
18

f函数确实只是18*t

>>> import dis
>>> dis.dis(f)
  1           0 LOAD_CONST               1 (18)
              3 LOAD_FAST                0 (_Dummy_18)
              6 BINARY_MULTIPLY
              7 RETURN_VALUE

如果要将生成的代码编译为机器代码,可以尝试JIT编译器,例如NumbaTheanoParakeet

答案 1 :(得分:1)

以下是我将如何处理此问题:

  1. compile()您的函数指向AST(抽象语法树)而不是正常的字节码函数 - 有关详细信息,请参阅标准ast模块。
  2. 遍历AST,用固定值替换对固定参数的所有引用。像macropy这样的库可能对此有用,我没有任何具体的建议。
  3. 再次遍历AST,执行可能启用的任何优化,例如Mult(1,X)=&gt; X.你不必担心两个常量之间的操作,因为Python(自2.6起)已经对它进行了优化。
  4. compile()将AST转换为正常函数。称之为,并希望速度增加足够的数量来证明所有预优化的合理性。
  5. 请注意,Python永远不会自己优化像1 * X这样的东西,因为它无法知道X在运行时的类型 - 它可能是一个以任意方式实现乘法运算的类的实例,所以结果不一定是X.只有你知道所有变量都是普通数,遵守通常的算术规则,才能使这种优化有效。

答案 2 :(得分:1)

解决此类问题的“正确方法”是以下一项或多项:

  1. 寻找更有效的配方
  2. 象征性地简化和减少术语
  3. 使用矢量化(例如NumPy)
  4. 寻找已经优化的低级库(例如,在隐式执行强表达式优化的C或Fortran等语言中,而不是使用 nada 的Python)。
  5. 我们暂时说一下,接近1,3和4是不可用的,你必须在Python中这样做。然后,简化和“提升”常见的子表达式是您的主要工具。

    好消息是,有很多很多的机会。例如,表达式r**6重复26次。您可以通过简单地分配r_6 = r ** 6一次,然后在每次发生时替换r**6来保存25个计算。

    当您开始在此处查找常用表达式时,您会发现无处不在。机械化这个过程真好,对吗?通常,这需要完整表达式解析器(例如来自ast模块)并且是指数时优化问题。但你的表达有点特殊。虽然漫长而多变,但并不是特别复杂。它几乎没有内部括号分组,因此我们可以采用更快更脏的方法。

    在如何之前,生成的代码是:

    sa = r**6                      # 26 occurrences
    sb = u1**2                     # 5 occurrences
    sc = u2**2                     # 5 occurrences
    sd = v1**2                     # 5 occurrences
    se = u1**4                     # 4 occurrences
    sf = u2**3                     # 3 occurrences
    sg = u1**3                     # 3 occurrences
    sh = v1**4                     # 3 occurrences
    si = u2**4                     # 3 occurrences
    sj = v1**3                     # 3 occurrences
    sk = v2**2                     # 1 occurrence
    sl = v1**6                     # 1 occurrence
    sm = v1**5                     # 1 occurrence
    sn = u1**6                     # 1 occurrence
    so = u1**5                     # 1 occurrence
    sp = u2**6                     # 1 occurrence
    sq = u2**5                     # 1 occurrence
    sr = 6*sa                      # 6 occurrences
    ss = 3*sa                      # 5 occurrences
    st = ss*t                      # 5 occurrences
    su = 12*sa                     # 4 occurrences
    sv = sa*t                      # 3 occurrences
    sw = v1*v2                     # 5 occurrences
    sx = sj*v2                     # 3 occurrences
    sy = 24*sv                     # 3 occurrences
    sz = 15*sv                     # 2 occurrences
    sA = sr*u1                     # 2 occurrences
    sB = sy*u1                     # 2 occurrences
    sC = sb*sc                     # 2 occurrences
    sD = st*se                     # 2 occurrences
    
    # revised formula
    sv*sn - sr*so*u2 - sz*se*sc +
    20*sa*sg*sf + sz*sb*si - sA*sq -
    sv*sp + sD*sd - su*sg*u2*sd -
    18*sv*sC*sd + su*u1*sf*sd +
    st*si*sd + st*sb*sh - sA*u2*sh -
    st*sc*sh + sv*sl - sr*se*sw -
    sy*sg*u2*sw + 36*sa*sC*sw +
    sB*sf*sw - sr*si*sw -
    su*sb*sx - sB*u2*sx +
    su*sc*sx - sr*sm*v2 - sD*sk
    

    这避免了81次计算。这只是一个粗略的切入点。甚至结果也可以进一步改善。例如,子表达式sr*swsu*sd也可以预先计算。但是我们将在下一个级别停留一天。

    请注意,这不包括起始r*((-16*(。大多数简化可以(并且需要)在表达式的核心上完成,而不是在其外部术语上完成。所以我暂时把它们除掉了;一旦计算出公共核心,就可以将它们加回来。

    你是怎么做到的?

    f = """
    r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
    20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
    r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
    18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
    3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
    3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
    24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
    24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
    12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
    12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2
    """.strip()
    
    
    from collections import Counter
    import re
    
    expre = re.compile('(?<!\w)\w+\*\*\d+')
    multre = re.compile('(?<!\w)\w+\*\w+')
    
    expr_saved = 0
    stmts = []
    
    
    secache = {}
    seindex = 0
    def subexpr(e):
        global seindex
        cached = secache.get(e)
        if cached:
            return cached
        base = ord('a') if seindex < 26 else ord('A') - 26
        name = 's' + chr(seindex + base)
        seindex += 1
        secache[e] = name
        return name
    
    def hoist(e, flat, c):
        """
        Hoist the expression e into name defined by flat.
        c is the count of how many times seen in incoming
        formula.
        """
        global expr_saved
    
        assign = "{} = {}".format(flat, e)
        s = "{:30} # {} occurrence{}".format(assign, c, '' if c == 1 else 's')
        stmts.append(s)
        print "{} needless computations quashed with {}".format(c-1, flat)
        expr_saved += c - 1
    
    def common_exp(form):
        """
        Replace ALL exponentiation operations with a hoisted
        sub-expression.
        """
        # find the exponentiation operations
        exponents = re.findall(expre, form)
    
        # find and count exponentiation operations
        expcount = Counter(re.findall(expre, form))
    
        # for each exponentiation, create a hoisted sub-expression
        for e, c in expcount.most_common():
            hoist(e, subexpr(e), c)
    
        # replace all exponentiation operations with their sub-expressions
        form = re.sub(expre, lambda x: subexpr(x.group(0)), form)
        return form
    
    
    def common_mult(f):
        """
        Replace multiplication operations with a hoisted
        sub-expression if they occur > 1 time. Also, only
        replaces one sub-expression at a time (the most common)
        because it may affect further expressions
        """
        mults = re.findall(multre, f)
        for e, c in Counter(mults).most_common():
            # unlike exponents, only replace if >1 occurrence
            if c == 1:
                return f
            # occurs >1 time, so hoist
            hoist(e, subexpr(e), c)
            # replace in loop and return
            return re.sub('(?<!\w)' + re.escape(e), subexpr(e), f)
            # return f.replace(e, flat(e))
        return f
    
    # fix all exponents
    form = common_exp(f)
    
    # fix selected multiplies
    prev = form
    while True:
        form = common_mult(form)
        if form == prev:
            # have converged; no more replacements possible
            break
        prev = form
    
    print "--"
    mults = re.split(r'\s*[+-]\s*', form)
    smults = ['*'.join(sorted(terms.split('*'))) for terms in mults]
    print smults
    
    # print the hoisted statements and the revised expression
    print '\n'.join(stmts)
    print
    print "# revised formula"
    print form
    

    使用正则表达式进行解析是很冒险的事情。这段旅程容易出错,悲伤和遗憾。我通过提升一些并非严格要求的指数来防止不良结果,并将随机值插入前后公式中以确保它们都给出相同的结果。如果这是生产代码,我建议使用“pnt to C”策略。但如果你不能......