我有一个很长的数学公式(只是为了让你进入上下文:它有293095个字符),实际上它将是python函数的主体。此函数具有15
输入参数,如:
def math_func(t,X,P,n1,n2,R,r):
x,y,z = X
a,b,c = P
u1,v1,w1 = n1
u2,v2,w2 = n2
return <long math formula>
该公式使用简单的数学运算+ - * ** /
和一个函数调用arctan
。这是它的摘录:
r*((-16*(r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2 + ...
现在重点是,在实践中,此功能的批量评估将针对P,n1,n2,R
和r
的固定值进行,这会将 free 变量的集合减少到仅四,“理论上”参数较少的公式应该更快。
所以问题是:如何在Python中实现此优化?
我知道我可以将所有内容都放在一个字符串中并执行某种replace
,compile
和eval
,如
formula = formula.replace('r','1').replace('R','2')....
code = compile(formula,'formula-name','eval')
math_func = lambda t,x,y,z: eval(code)
如果某些操作(如电源)被其值替换会很好,例如18*r**6*t*u1**2*u2**2*v1**2
应该成为18*t
的{{1}}。我认为r=u1=u2=v1=1
应该这样做,但无论如何我不确定。 compile
是否实际执行此优化?
我的解决方案加快了计算速度,但是如果我可以更多地进行计算,那就更好了。注意:最好在标准Python中(我可以稍后尝试Cython)。
总的来说,我很有兴趣以 pythonic 的方式来实现我的目标,可能还有一些额外的库:什么是相当不错的方式?我的解决方案是良好的方法吗?
编辑:(提供更多背景信息)
巨大的表达式是在圆弧上积分的符号线的输出。弧在空间中由半径compile
,两个正常法向量(如2D版本中的x和y轴)r
,n1=(u1,v1,w1)
和中心{{1}给出}。其余的是我正在为我正在集成的函数执行集成n2=(u2,v2,w2)
和参数P=(a,b,c)
的点。
X=(x,y,z)
和R
只需要很长时间来计算,实际输出来自Sympy
。
如果你对这里的公式感到好奇,那就是(伪伪代码):
Maple
答案 0 :(得分:2)
你可以使用Sympy:
>>> from sympy import symbols
>>> x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r = symbols("x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r")
>>> r=u1=u2=v1=1
>>> a = 18*r**6*t*u1**2*u2**2*v1**2
>>> a
18*t
然后你可以创建一个像这样的Python函数:
>>> from sympy import lambdify
>>> f = lambdify(t, a)
>>> f(1)
18
f
函数确实只是18*t
:
>>> import dis
>>> dis.dis(f)
1 0 LOAD_CONST 1 (18)
3 LOAD_FAST 0 (_Dummy_18)
6 BINARY_MULTIPLY
7 RETURN_VALUE
答案 1 :(得分:1)
以下是我将如何处理此问题:
compile()
您的函数指向AST(抽象语法树)而不是正常的字节码函数 - 有关详细信息,请参阅标准ast
模块。compile()
将AST转换为正常函数。称之为,并希望速度增加足够的数量来证明所有预优化的合理性。请注意,Python永远不会自己优化像1 * X这样的东西,因为它无法知道X在运行时的类型 - 它可能是一个以任意方式实现乘法运算的类的实例,所以结果不一定是X.只有你知道所有变量都是普通数,遵守通常的算术规则,才能使这种优化有效。
答案 2 :(得分:1)
解决此类问题的“正确方法”是以下一项或多项:
我们暂时说一下,接近1,3和4是不可用的,你必须在Python中这样做。然后,简化和“提升”常见的子表达式是您的主要工具。
好消息是,有很多很多的机会。例如,表达式r**6
重复26次。您可以通过简单地分配r_6 = r ** 6
一次,然后在每次发生时替换r**6
来保存25个计算。
当您开始在此处查找常用表达式时,您会发现无处不在。机械化这个过程真好,对吗?通常,这需要完整表达式解析器(例如来自ast
模块)并且是指数时优化问题。但你的表达有点特殊。虽然漫长而多变,但并不是特别复杂。它几乎没有内部括号分组,因此我们可以采用更快更脏的方法。
在如何之前,生成的代码是:
sa = r**6 # 26 occurrences
sb = u1**2 # 5 occurrences
sc = u2**2 # 5 occurrences
sd = v1**2 # 5 occurrences
se = u1**4 # 4 occurrences
sf = u2**3 # 3 occurrences
sg = u1**3 # 3 occurrences
sh = v1**4 # 3 occurrences
si = u2**4 # 3 occurrences
sj = v1**3 # 3 occurrences
sk = v2**2 # 1 occurrence
sl = v1**6 # 1 occurrence
sm = v1**5 # 1 occurrence
sn = u1**6 # 1 occurrence
so = u1**5 # 1 occurrence
sp = u2**6 # 1 occurrence
sq = u2**5 # 1 occurrence
sr = 6*sa # 6 occurrences
ss = 3*sa # 5 occurrences
st = ss*t # 5 occurrences
su = 12*sa # 4 occurrences
sv = sa*t # 3 occurrences
sw = v1*v2 # 5 occurrences
sx = sj*v2 # 3 occurrences
sy = 24*sv # 3 occurrences
sz = 15*sv # 2 occurrences
sA = sr*u1 # 2 occurrences
sB = sy*u1 # 2 occurrences
sC = sb*sc # 2 occurrences
sD = st*se # 2 occurrences
# revised formula
sv*sn - sr*so*u2 - sz*se*sc +
20*sa*sg*sf + sz*sb*si - sA*sq -
sv*sp + sD*sd - su*sg*u2*sd -
18*sv*sC*sd + su*u1*sf*sd +
st*si*sd + st*sb*sh - sA*u2*sh -
st*sc*sh + sv*sl - sr*se*sw -
sy*sg*u2*sw + 36*sa*sC*sw +
sB*sf*sw - sr*si*sw -
su*sb*sx - sB*u2*sx +
su*sc*sx - sr*sm*v2 - sD*sk
这避免了81次计算。这只是一个粗略的切入点。甚至结果也可以进一步改善。例如,子表达式sr*sw
和su*sd
也可以预先计算。但是我们将在下一个级别停留一天。
请注意,这不包括起始r*((-16*(
。大多数简化可以(并且需要)在表达式的核心上完成,而不是在其外部术语上完成。所以我暂时把它们除掉了;一旦计算出公共核心,就可以将它们加回来。
你是怎么做到的?
f = """
r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2
""".strip()
from collections import Counter
import re
expre = re.compile('(?<!\w)\w+\*\*\d+')
multre = re.compile('(?<!\w)\w+\*\w+')
expr_saved = 0
stmts = []
secache = {}
seindex = 0
def subexpr(e):
global seindex
cached = secache.get(e)
if cached:
return cached
base = ord('a') if seindex < 26 else ord('A') - 26
name = 's' + chr(seindex + base)
seindex += 1
secache[e] = name
return name
def hoist(e, flat, c):
"""
Hoist the expression e into name defined by flat.
c is the count of how many times seen in incoming
formula.
"""
global expr_saved
assign = "{} = {}".format(flat, e)
s = "{:30} # {} occurrence{}".format(assign, c, '' if c == 1 else 's')
stmts.append(s)
print "{} needless computations quashed with {}".format(c-1, flat)
expr_saved += c - 1
def common_exp(form):
"""
Replace ALL exponentiation operations with a hoisted
sub-expression.
"""
# find the exponentiation operations
exponents = re.findall(expre, form)
# find and count exponentiation operations
expcount = Counter(re.findall(expre, form))
# for each exponentiation, create a hoisted sub-expression
for e, c in expcount.most_common():
hoist(e, subexpr(e), c)
# replace all exponentiation operations with their sub-expressions
form = re.sub(expre, lambda x: subexpr(x.group(0)), form)
return form
def common_mult(f):
"""
Replace multiplication operations with a hoisted
sub-expression if they occur > 1 time. Also, only
replaces one sub-expression at a time (the most common)
because it may affect further expressions
"""
mults = re.findall(multre, f)
for e, c in Counter(mults).most_common():
# unlike exponents, only replace if >1 occurrence
if c == 1:
return f
# occurs >1 time, so hoist
hoist(e, subexpr(e), c)
# replace in loop and return
return re.sub('(?<!\w)' + re.escape(e), subexpr(e), f)
# return f.replace(e, flat(e))
return f
# fix all exponents
form = common_exp(f)
# fix selected multiplies
prev = form
while True:
form = common_mult(form)
if form == prev:
# have converged; no more replacements possible
break
prev = form
print "--"
mults = re.split(r'\s*[+-]\s*', form)
smults = ['*'.join(sorted(terms.split('*'))) for terms in mults]
print smults
# print the hoisted statements and the revised expression
print '\n'.join(stmts)
print
print "# revised formula"
print form
使用正则表达式进行解析是很冒险的事情。这段旅程容易出错,悲伤和遗憾。我通过提升一些并非严格要求的指数来防止不良结果,并将随机值插入前后公式中以确保它们都给出相同的结果。如果这是生产代码,我建议使用“pnt to C”策略。但如果你不能......