Parsing an equation with custom functions in Python

Date: 2015-01-15 10:04:25

Tags: python python-2.7 parsing equation

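I have a string containing a mathematical equation, but it uses some custom functions. I need to find all of these functions and replace them with some code.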

For example, I have the string:

a+b+f1(f2(x,y),x)

I want the code to replace f2(x,y) with x+y^2 and f1(x,y) with sin(x+y).

Ideally nested functions would be supported, as in the example. But even without nesting support, it would still be useful.

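From similar topics I understand this can be done with the compiler module, e.g. compiler.parse(eq). How can I rebuild my string from the AST object created by compiler.parse(eq), replacing every function found?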

I only need to perform the substitution; the resulting string will then be used in another program. No evaluation is needed.

6 answers:

Answer 0 (score: 7)

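Here is a minimal working example (it implements the binary and unary operations +, -, *, /, ** and function calls). Operator precedence is made explicit with parentheses.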

It does a little more than the given example requires:

from __future__ import print_function
import ast

def transform(eq, functions):
    """Inline the user functions in `functions` into the expression string `eq`."""

    class EqVisitor(ast.NodeVisitor):
        # Binary operations are emitted fully parenthesized to preserve precedence.
        def visit_BinOp(self, node):
            generate("(")
            self.visit(node.left)
            self.visit(node.op)
            self.visit(node.right)
            generate(")")
        def visit_USub(self, node):
            generate("-")
        def visit_UAdd(self, node):
            generate("+")
        def visit_Sub(self, node):
            generate("-")
        def visit_Add(self, node):
            generate("+")
        def visit_Pow(self, node):
            generate("**")
        def visit_Mult(self, node):
            generate("*")
        def visit_Div(self, node):
            generate("/")
        def visit_Name(self, node):
            generate(node.id)
        def visit_Call(self, node):
            debug("function", node.func.id)
            if node.func.id in functions:
                # User-defined function: emit its body with the arguments inlined.
                debug("defined function")
                func_visit(functions[node.func.id], node.args)
                return
            # Any other function (e.g. sin) is emitted unchanged.
            debug("not defined function", node.func.id)
            generate(node.func.id)
            generate("(")
            sep = ""
            for arg in node.args:
                generate(sep)
                self.visit(arg)
                sep = ","
            generate(")")
        def visit_Num(self, node):
            generate(node.n)
        def generic_visit(self, node):
            debug("\n", type(node).__name__)
            debug(node._fields)
            ast.NodeVisitor.generic_visit(self, node)

    def func_visit(definition, concrete_args):
        class FuncVisitor(EqVisitor):
            def visit_arguments(self, node):
                # Map each formal parameter name to the AST of the actual argument.
                # Python 2: parameters are Name nodes (.id); Python 3 uses ast.arg (.arg).
                self.arguments = {}
                for concrete_arg, formal_arg in zip(concrete_args, node.args):
                    self.arguments[formal_arg.id] = concrete_arg
                debug(self.arguments)
            def visit_Name(self, node):
                debug("visit Name", node.id)
                if node.id in self.arguments:
                    # Arguments are rendered by the outer visitor, so nested calls expand too.
                    eqV.visit(self.arguments[node.id])
                else:
                    generate(node.id)

        funcV = FuncVisitor()
        funcV.visit(ast.parse(definition))

    eqV = EqVisitor()
    result = []
    def generate(s):
        debug(str(s))
        result.append(str(s))
    eqV.visit(ast.parse(eq, mode="eval"))
    return "".join(result)

def debug(*args, **kwargs):
    # Uncomment the next line to trace the traversal.
    #print(*args, **kwargs)
    pass

Usage:

functions= {
    "f1":"def f1(x,y):return x+y**2",
    "f2":"def f2(x,y):return sin(x+y)",
}
eq="-(a+b)+f1(f2(+x,y),z)*4/365.12-h"
print(transform(eq,functions))

Result:

((-(a+b)+(((sin((+x+y))+(z**2))*4)/365.12))-h)

Warning

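The code works with Python 2.7. Because it depends on the layout of the AST, it is not guaranteed to work with other Python versions, and it does not work on Python 3 as written (for example, Python 3 represents function parameters as ast.arg nodes with an .arg attribute, not Name nodes with .id).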
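For reference, here is one possible Python 3.8+ port of the same idea, written as a recursive function over the AST instead of a NodeVisitor; the name transform_py3 and the OPS table are my own. It reads parameter names from ast.arg.arg and numbers from ast.Constant, and, like the original, emits a fully parenthesized string. Treat it as a sketch: it only handles the node types listed.

```python
import ast

# Operator spellings for the output string.
OPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/", ast.Pow: "**",
       ast.UAdd: "+", ast.USub: "-"}

def transform_py3(eq, functions):
    """Inline user functions into eq, emitting a fully parenthesized string.

    functions maps a name to source text such as "def f1(x,y):return x+y**2".
    """
    defs = {}
    for name, source in functions.items():
        fdef = ast.parse(source).body[0]            # the FunctionDef node
        params = [a.arg for a in fdef.args.args]    # Python 3: ast.arg has .arg
        defs[name] = (params, fdef.body[0].value)   # expression of the return

    def emit(node, env):
        # env maps a parameter name to the already-rendered argument string
        if isinstance(node, ast.BinOp):
            return "(%s%s%s)" % (emit(node.left, env), OPS[type(node.op)],
                                 emit(node.right, env))
        if isinstance(node, ast.UnaryOp):
            return OPS[type(node.op)] + emit(node.operand, env)
        if isinstance(node, ast.Name):
            return env.get(node.id, node.id)
        if isinstance(node, ast.Constant):          # numbers (Python 3.8+)
            return str(node.value)
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            args = [emit(arg, env) for arg in node.args]
            if node.func.id in defs:
                params, body = defs[node.func.id]
                return emit(body, dict(zip(params, args)))
            return "%s(%s)" % (node.func.id, ",".join(args))
        raise ValueError("unsupported syntax: " + type(node).__name__)

    return emit(ast.parse(eq, mode="eval").body, {})

functions = {"f1": "def f1(x,y):return x+y**2",
             "f2": "def f2(x,y):return sin(x+y)"}
print(transform_py3("-(a+b)+f1(f2(+x,y),z)*4/365.12-h", functions))
```

Because arguments are rendered to strings before the body is emitted, a parameter name inside a function body can never accidentally capture a variable from the caller's expression.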

Answer 1 (score: 3)

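Full substitution is quite tricky. Here is my attempt. It can inline expressions successfully, but not in every case. The code works only on ASTs produced by the ast module, and it uses codegen to turn them back into source code. Stringifying an AST, and modifying ASTs in general, is covered in another SO Q/A: "Parse a .py file, read the AST, modify it, then write back the modified source code".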

First we define some helpers:

import ast
import codegen
import copy

def parseExpr(expr):
    # Strip the Module(body=[Expr(value=...)]) wrapper and return the bare expression
    return ast.parse(expr).body[0].value

def toSource(expr):
    return codegen.to_source(expr)
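On Python 3.9 and later, the standard library's ast.unparse can play the role of codegen.to_source, so the same helpers can be written without a third-party package (the snake_case names are my own). ast.unparse only adds the parentheses that operator precedence requires.

```python
import ast

def parse_expr(expr):
    # mode="eval" parses a single expression; .body strips the Expression wrapper
    return ast.parse(expr, mode="eval").body

def to_source(expr):
    # ast.unparse is in the standard library from Python 3.9
    return ast.unparse(expr)

print(to_source(parse_expr("((a + b) + c)")))
```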

Then we define a substitution function using a NodeTransformer. For example:

substitute(parseExpr("a + b"), { "a": parseExpr("1") }) # 1 + b

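Multiple variables must be substituted simultaneously to handle nasty cases correctly. For example, substitute a + b for both a and b in a + b. The result should be (a + b) + (a + b), but if we first substitute a + b for a, we get (a + b) + b, and then substituting for b gives (a + (a + b)) + (a + b), which is the wrong result! So simultaneous substitution matters: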

class NameTransformer(ast.NodeTransformer):
    def __init__(self, names):
        self.names = names

    def visit_Name(self, node):
        if node.id in self.names:
            return self.names[node.id]
        else:
            return node

def substitute(expr, names):
    print "substitute"
    for varName, varValue in names.iteritems():
        print "  name " + varName + " for " + toSource(varValue)
    print "  in " + toSource(expr)
    return NameTransformer(names).visit(expr)
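The difference between simultaneous and one-at-a-time substitution can be checked directly. This is a small Python 3.9+ sketch (it relies on ast.unparse); SimultaneousSubstitution and subst are illustrative names. One deliberate change from NameTransformer above: each occurrence receives its own deep copy of the replacement, so no AST node ends up shared between two places in the tree.

```python
import ast
import copy

class SimultaneousSubstitution(ast.NodeTransformer):
    """Replace every listed name in a single pass over the tree."""
    def __init__(self, mapping):
        self.mapping = mapping                      # name -> expression AST
    def visit_Name(self, node):
        if node.id in self.mapping:
            # A fresh copy per occurrence; the inserted node is not visited again.
            return copy.deepcopy(self.mapping[node.id])
        return node

def subst(source, replacements):
    tree = ast.parse(source, mode="eval")
    mapping = {name: ast.parse(text, mode="eval").body
               for name, text in replacements.items()}
    return ast.unparse(SimultaneousSubstitution(mapping).visit(tree))

# Both names at once: the correct result.
together = subst("a + b", {"a": "a + b", "b": "a + b"})
# One name at a time: the b introduced by the first step is replaced again.
one_by_one = subst(subst("a + b", {"a": "a + b"}), {"b": "a + b"})
print(together)    # a + b + (a + b)
print(one_by_one)  # a + (a + b) + (a + b)
```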

Then we write a similar NodeTransformer that finds calls and inlines the function definitions:

class CallTransformer(ast.NodeTransformer):
    def __init__(self, fnName, varNames, fnExpr):
        self.fnName = fnName
        self.varNames = varNames
        # substitute in new fn expr for each CallTransformer
        self.fnExpr = copy.deepcopy(fnExpr)
        self.modified = False

    def visit_Call(self, node):
        if (node.func.id == self.fnName):
            if len(node.args) == len(self.varNames):
                print "expand call to " + self.fnName + "(" + (", ".join(self.varNames)) + ")" + " with arguments "+ ", ".join(map(toSource, node.args))
                # We substitute in args too!
                old_node = node
                args = map(self.visit, node.args)
                names = dict(zip(self.varNames, args))
                node = substitute(self.fnExpr, names)
                self.modified = True
                return node
            else:
                raise Exception("invalid arity " + toSource(node))
        else:
            return self.generic_visit(node)

def substituteCalls(expr, definitions, n = 3):
    while True:
        if (n <= 0):
            break
        n -= 1

        modified = False
        for fnName, varNames, fnExpr in definitions:
            transformer = CallTransformer(fnName, varNames, fnExpr)
            expr = transformer.visit(expr)
            modified = modified or transformer.modified

        if not modified:
            break

    return expr

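substituteCalls repeats the expansion pass until nothing changes, so recursive functions can be inlined as well. The explicit limit n is there because some definitions recurse infinitely (like fact below). There is some ugly copying, but it is needed to keep different subtrees separate.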
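Here is a Python 3.9+ sketch of the same bounded, repeated expansion (names such as substitute_calls and max_passes are my own, and ast.unparse replaces codegen). Unlike CallTransformer, it substitutes into a fresh deep copy of the template on every expansion, so two calls to the same function within one pass cannot interfere with each other.

```python
import ast
import copy

class _Subst(ast.NodeTransformer):
    """Replace parameter names by deep copies of the arguments, all at once."""
    def __init__(self, mapping):
        self.mapping = mapping
    def visit_Name(self, node):
        if node.id in self.mapping:
            return copy.deepcopy(self.mapping[node.id])
        return node

class _Expand(ast.NodeTransformer):
    """Expand calls to one user-defined function and record whether anything changed."""
    def __init__(self, name, params, body):
        self.name, self.params, self.body = name, params, body
        self.modified = False
    def visit_Call(self, node):
        self.generic_visit(node)  # expand calls nested in the arguments first
        if isinstance(node.func, ast.Name) and node.func.id == self.name:
            if len(node.args) != len(self.params):
                raise ValueError("invalid arity for " + self.name)
            self.modified = True
            # Work on a copy so the template itself is never mutated.
            mapping = dict(zip(self.params, node.args))
            return _Subst(mapping).visit(copy.deepcopy(self.body))
        return node

def substitute_calls(expr, definitions, max_passes=3):
    """definitions: list of (name, [param, ...], body_source) tuples."""
    tree = ast.parse(expr, mode="eval")
    defs = [(n, p, ast.parse(b, mode="eval").body) for n, p, b in definitions]
    for _ in range(max_passes):  # explicit limit: recursive definitions never settle
        modified = False
        for name, params, body in defs:
            expander = _Expand(name, params, body)
            tree = expander.visit(tree)
            modified = modified or expander.modified
        if not modified:
            break
    return ast.unparse(tree)

print(substitute_calls("f1(f1(x, x), y)", [("f1", ["x", "y"], "x + y")]))
```

Each pass expands one level of a recursive definition, because the body inserted for a call is not itself visited again in the same pass.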


Example code:

if True:
    print "f1 first, unique variable names"
    ex = parseExpr("a+b+f1(f2(x, y), x)")
    ex = substituteCalls(ex, [
        ("f1", ["u", "v"], parseExpr("sin(u + v)")),
        ("f2", ["i", "j"], parseExpr("i + j ^ 2"))])
    print toSource(ex)
    print "---"

if True:
    print "f1 first"
    ex = parseExpr("a+b+f1(f2(x, y), x)")
    ex = substituteCalls(ex, [
        ("f1", ["x", "y"], parseExpr("sin(x + y)")),
        ("f2", ["x", "y"], parseExpr("x + y ^ 2"))])
    print toSource(ex)
    print "---"

if True:
    print "f2 first"
    ex = parseExpr("f1(f1(x, x), y)")
    ex = substituteCalls(ex, [
        ("f1", ["x", "y"], parseExpr("x + y"))])
    print toSource(ex)
    print "---"

if True:
    print "fact"
    ex = parseExpr("fact(n)")
    ex = substituteCalls(ex, [
        ("fact", ["n"], parseExpr("n if n == 0 else n * fact(n-1)"))])
    print toSource(ex)
    print "---"

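This prints the output below. Two things to keep in mind when reading it. First, in Python ^ is bitwise XOR, which binds more loosely than +, so i + j ^ 2 is parsed as (i + j) ^ 2; use ** for exponentiation. Second, the "f2 first" result is wrong: f1(f1(x, x), y) should expand to (x + x) + y. Its log shows why: the second substitution runs on the template (x + x) instead of (x + y), because substitute() rewrites the CallTransformer's shared self.fnExpr in place. Passing a fresh copy, copy.deepcopy(self.fnExpr), to substitute() on each expansion should avoid this.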

f1 first, unique variable names
expand call to f1(u, v) with arguments f2(x, y), x
substitute
  name u for f2(x, y)
  name v for x
  in sin((u + v))
expand call to f2(i, j) with arguments x, y
substitute
  name i for x
  name j for y
  in ((i + j) ^ 2)
((a + b) + sin((((x + y) ^ 2) + x)))
---
f1 first
expand call to f1(x, y) with arguments f2(x, y), x
substitute
  name y for x
  name x for f2(x, y)
  in sin((x + y))
expand call to f2(x, y) with arguments x, y
substitute
  name y for y
  name x for x
  in ((x + y) ^ 2)
((a + b) + sin((((x + y) ^ 2) + x)))
---
f2 first
expand call to f1(x, y) with arguments f1(x, x), y
expand call to f1(x, y) with arguments x, x
substitute
  name y for x
  name x for x
  in (x + y)
substitute
  name y for y
  name x for (x + x)
  in (x + x)
((x + x) + ((x + x) + x))
---
fact
expand call to fact(n) with arguments n
substitute
  name n for n
  in n if (n == 0) else (n * fact((n - 1)))
expand call to fact(n) with arguments (n - 1)
substitute
  name n for (n - 1)
  in n if (n == 0) else (n * fact((n - 1)))
expand call to fact(n) with arguments ((n - 1) - 1)
substitute
  name n for ((n - 1) - 1)
  in n if (n == 0) else (n * fact((n - 1)))
n if (n == 0) else (n * (n - 1) if ((n - 1) == 0) else ((n - 1) * ((n - 1) - 1) if (((n - 1) - 1) == 0) else (((n - 1) - 1) * fact((((n - 1) - 1) - 1)))))

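Unfortunately, the PyPI version of codegen does not parenthesize expressions correctly, even where the AST says it should. I used jbremer/codegen (pip install git+git://github.com/jbremer/codegen). It also adds unnecessary parentheses, but that is better than nothing. Thanks to @XavierCombelle for the tip. (On Python 3.9 and later, the standard library's ast.unparse can be used instead.)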


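Substitution gets even trickier if you have anonymous functions, i.e. lambdas, because then you need to rename variables. You could try searching for lambda calculus substitution implementations. However, I couldn't find any article doing this in Python.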

Answer 2 (score: 2)

Do you know the variables in advance?

I'd suggest using SymPy!

For example:

import sympy

a,b,x,y = sympy.symbols('a b x y')
f1 = sympy.Function('f1')
f2 = sympy.Function('f2')

readString = "a+b+f1(f2(x,y),x)"

z = eval(readString)

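z is now a symbolic term representing the formula, and you can print it. You can then use subs to replace symbolic terms or functions. You can represent sine symbolically as well (like f1 and f2), or use a concrete sin(), e.g. sympy.sin (the answer originally pointed to sympy.mpmath; in current SymPy, mpmath is a separate package).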

Depending on your needs, this approach is attractive because you can later compute, evaluate, or simplify the expression.
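The substitution step itself can be sketched with Basic.replace, which, when given a function class as the pattern, calls your callable with the arguments of every application it finds, nested ones included. This sketch also swaps the bare eval() for sympy.sympify, which converts ^ to ** by default; note that sympify still uses eval internally, so the input must come from a trusted source.

```python
import sympy

x, y = sympy.symbols("x y")
f1 = sympy.Function("f1")
f2 = sympy.Function("f2")

# Unknown names such as a and b become Symbols; f1 and f2 map to the classes above.
expr = sympy.sympify("a+b+f1(f2(x,y),x)", locals={"f1": f1, "f2": f2})

# Replace every application of f2, then of f1, by its definition.
expr = expr.replace(f2, lambda u, v: u + v**2)
expr = expr.replace(f1, lambda u, v: sympy.sin(u + v))
print(expr)
```

SymPy canonicalizes automatically (here x + x becomes 2*x), so the printed string may be rearranged relative to a purely textual substitution.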

Answer 3 (score: 1)

(Uses sympy as adrianX suggested, with some extra code.)

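The code below converts the given string into a new string by composing it with the given functions. It is hastily written and poorly documented, but it works.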


Warning!

The code contains exec and eval, so if external users supply the input, malicious code could be executed.


Update:

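  • Rewrote the entire code. Works with Python 2.7.
  • Function arguments can be separated by commas, spaces, or both.
  • All examples from the question and the comments work.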

import re
import sympy


##################################################
# Input string and functions

initial_str = 'a1+myf1(myf2(a, b),y)'
given_functions = {'myf1(x,y)': 'cross(x,y)', 'myf2(a, b)': 'value(a,b)'}
##################################################


print '\nEXECUTED/EVALUATED STUFF:\n'


processed_str = initial_str


def fixed_power_op(str_to_fix):
    return str_to_fix.replace('^', '**')


def fixed_multiplication(str_to_fix):
    """
    Inserts multiplication symbol wherever omitted.
    """

    pattern_digit_x = r"(\d)([A-Za-z])"         # 4x -> 4*x
    pattern_par_digit = r"(\))(\d)"             # )4 -> )*4
    pattern_digit_par = r"(?<![A-Za-z_\d])(\d+)(\()"  # 4( -> 4*(, but not f4(

    for patt in (pattern_digit_x, pattern_par_digit, pattern_digit_par):
        str_to_fix = re.sub(patt, r'\1*\2', str_to_fix)

    return str_to_fix


processed_str = fixed_power_op(processed_str)


class FProcessing(object):

    def __init__(self, func_key, func_body):
        self.func_key = func_key
        self.func_body = func_body

    def sliced_func_name(self):
        return re.sub(r'(.+)\(.+', r'\1', self.func_key)

    def sliced_func_args(self):
        return re.search(r'\((.*)\)', self.func_key).group()

    def sliced_args(self):
        """
        Returns arguments found for given function. Arguments can be separated by comma or whitespace.

        :returns (list)
        """

        if ',' in self.sliced_func_args():
            arg_separator = ','
        else:
            arg_separator = ' '

        return self.sliced_func_args().replace('(', '').replace(')', '').split(arg_separator)

    def num_of_sliced_args(self):
        """
        Returns number of arguments found for given function.
        """
        return len(self.sliced_args())

    def functions_in_function_body(self):
        """
        Detects functions in function body.

        e.g. f1(x,y): sin(x+y**2), will result in "sin"

        :returns (set)
        """

        return set(re.findall(r'([a-zA-Z]+_?\w*)\(', self.func_body))

    def symbols_in_func_body(self):
        """
        Detects non argument symbols in function body.
        """

        symbols_in_body = set(re.findall(r'[a-zA-Z]+_\w*', self.func_body))

        return symbols_in_body - self.functions_in_function_body()


# --------------------------------------------------------------------------------------
# SYMBOL DETECTION (x, y, z, mz,..)


# Prohibited symbols
prohibited_symbol_names = set()
# Custom function names are prohibited symbol names.
for key in given_functions.keys():
    prohibited_symbol_names |= {FProcessing(func_key=key, func_body=None).sliced_func_name()}


def symbols_in_str(provided_str):

    """
    Returns a set of symbol names that are contained in provided string.

    Allowed symbols start with a letter followed by 0 or more letters,
    and then 0 or more numbers (eg. x, x1, Na, Xaa_sd, xa123)
    """
    symbol_pattern = re.compile(r'[A-Za-z]+\d*')
    symbol_name_set = re.findall(symbol_pattern, provided_str)
    # Filters out prohibited.
    symbol_name_set = {i for i in symbol_name_set if (i not in prohibited_symbol_names)}

    return symbol_name_set


# ----------------------------------------------------------------
# EXEC SYMBOLS
symbols_in_given_str = symbols_in_str(initial_str)
# e.g. " x, y, sd = sympy.symbols('x y sd') "
symbol_string_to_exec = ', '.join(symbols_in_given_str)
symbol_string_to_exec += ' = '
symbol_string_to_exec += "sympy.symbols('%s')" % ' '.join(symbols_in_given_str)

exec symbol_string_to_exec


# -----------------------------------------------------------------------------------------
# FUNCTIONS

# Detects secondary functions (functions contained in body of given_functions dict)
sec_functions = set()
for key, val in given_functions.items():
    sec_functions |= FProcessing(func_key=key, func_body=val).functions_in_function_body()


def secondary_function_as_exec_str(func_key):
    """
    Used for functions that are contained in the function body of given_functions.

    E.g. given_functions = {'f1(x)': 'sin(4+x)'} detects "sin", for which
    this returns "sin = sympy.Function('sin')"

    :param func_key: (str)
    :return: (str)
    """

    returned_str = "%s = sympy.Function('%s')" % (func_key, func_key)

    print returned_str
    return returned_str


def given_function_as_sympy_class_as_str(func_key, func_body):
    """
    Builds the source of a sympy Function subclass for one given function
    (the caller execs it).

    E.g. for {'f1(x,y)': 'cross(x,y)'} it returns:
            class f1(sympy.Function):
                nargs = 2
                @classmethod
                def eval(cls, x,y):
                    return cross(x,y)

    :param func_key: (str)
    :param func_body: (str)
    :return: (str)
    """

    func_proc_instance = FProcessing(func_key=func_key, func_body=func_body)

    returned_str = 'class %s(sympy.Function): ' % func_proc_instance.sliced_func_name()
    returned_str += '\n\tnargs = %s' % func_proc_instance.num_of_sliced_args()
    returned_str += '\n\t@classmethod'

    returned_str += '\n\tdef eval(cls, %s):' % ','.join(func_proc_instance.sliced_args())
    returned_str = returned_str.replace("'", '')

    returned_str += '\n\t\treturn %s' % func_body

    returned_str = fixed_power_op(returned_str)

    print '\n', returned_str
    return returned_str


# Executes functions in given_functions' body
for name in sec_functions:
    exec secondary_function_as_exec_str(func_key=name)

# Executes given_functions
for key, val in given_functions.items():
    exec given_function_as_sympy_class_as_str(func_key=key, func_body=val)


final_result = eval(initial_str)


# PRINTING
print '\n' + ('-'*40)
print '\nRESULTS'

print '\nInitial string: \n%s' % initial_str

print '\nGiven functions:'
for key, val in given_functions.iteritems():
    print '%s: ' % key, val

print '\nResult: \n%s' % final_result
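The same composition can also be sketched without exec, by turning each given function into a sympy.Lambda and substituting it with Basic.replace. This is an alternative technique, not the code above; compose and the signature regex are my own, and sympify still uses eval internally, so the input must be trusted.

```python
import re
import sympy

def compose(initial_str, given_functions):
    """Inline given_functions into initial_str using SymPy, without exec.

    given_functions maps a signature such as "myf1(x,y)" to a body such as
    "cross(x,y)"; parameters may be separated by commas, spaces or both.
    """
    signatures = {}
    for signature, body in given_functions.items():
        name, params = re.match(r"\s*(\w+)\s*\((.*)\)\s*$", signature).groups()
        signatures[name] = (params.replace(",", " ").split(), body)

    # One undefined SymPy function per custom name, shared by every parse below.
    undefined = {name: sympy.Function(name) for name in signatures}
    lambdas = {
        name: sympy.Lambda(tuple(sympy.symbols(params)),
                           sympy.sympify(body, locals=undefined))
        for name, (params, body) in signatures.items()
    }

    expr = sympy.sympify(initial_str, locals=undefined)
    # Repeat so bodies that call other custom functions are expanded too
    # (enough passes for non-recursive definitions).
    for _ in range(len(lambdas)):
        for name, lam in lambdas.items():
            expr = expr.replace(undefined[name], lam)
    return expr

print(compose('a1+myf1(myf2(a, b),y)',
              {'myf1(x,y)': 'cross(x,y)', 'myf2(a, b)': 'value(a,b)'}))
```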

Answer 4 (score: 0)

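What is your long-term goal? Do you want to evaluate the functions, or simply perform the substitution? In the former case, you could try this (note that f1 and f2 could also be defined dynamically):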

import math

def f2(x, y):
    return x + y ** 2

def f1(x, y):
    return math.sin(x + y)

a, b = 1, 2
x, y = 3, 4
eval('a + b + f1(f2(x, y), x)')
# 2.991148690709596

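If you want to replace the functions and get back the modified expression, you will need some kind of AST parser. Be careful with eval, because it opens a security hole for malicious user input.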
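If evaluation is the goal, one way to avoid eval is to walk the AST yourself and execute only whitelisted nodes. This is a sketch of that general technique (safe_eval, BIN_OPS and UNARY_OPS are my own names), not part of the answer above; note that ** on large numbers can still consume a lot of CPU and memory.

```python
import ast
import math
import operator

BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
           ast.Div: operator.truediv, ast.Pow: operator.pow}
UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}

def safe_eval(expr, names, functions):
    """Evaluate an arithmetic expression; anything not whitelisted raises ValueError."""
    def ev(node):
        if isinstance(node, ast.Expression):
            return ev(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Name) and node.id in names:
            return names[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
            return BIN_OPS[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in UNARY_OPS:
            return UNARY_OPS[type(node.op)](ev(node.operand))
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id in functions and not node.keywords):
            return functions[node.func.id](*[ev(arg) for arg in node.args])
        raise ValueError("disallowed expression: " + ast.dump(node))
    return ev(ast.parse(expr, mode="eval"))

functions = {"f1": lambda x, y: math.sin(x + y), "f2": lambda x, y: x + y ** 2}
value = safe_eval("a + b + f1(f2(x, y), x)", {"a": 1, "b": 2, "x": 3, "y": 4}, functions)
print(value)
```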

Answer 5 (score: 0)

I think you want to use a parser generator such as PyBison.

Look at this example, which contains the basic code you need:

http://freenet.mcnabhosting.com/python/pybison/calc.py

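You would need to add a token type for function names and a grammar rule for function calls, and then perform the substitution in that rule's action whenever a function call is encountered.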
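PyBison's own API is not shown here. Instead, this is a dependency-free, hand-written recursive-descent parser swapped in to illustrate the same structure: one rule per precedence level plus a rule for function calls, whose action inlines the user function. The Parser class and the (params, body) format for functions are my own assumptions. It treats ^ as a right-associative power, as in the question, and emits a fully parenthesized string.

```python
import re

# A number, an identifier, or any single non-space character (operator/punctuation).
TOKEN = re.compile(r"\s*(?:(\d+(?:\.\d*)?)|([A-Za-z_]\w*)|(\S))")

class Parser:
    """Recursive-descent parser that inlines user functions while parsing.

    functions maps a name to (params, body), e.g. {"f2": (["x", "y"], "x+y^2")}.
    """
    def __init__(self, text, functions, env=None):
        self.tokens = TOKEN.findall(text)  # list of (number, name, op) triples
        self.pos = 0
        self.functions = functions
        self.env = env or {}  # parameter name -> already-rendered argument

    def peek(self):
        if self.pos < len(self.tokens):
            return self.tokens[self.pos]
        return ("", "", "")  # end of input

    def take(self, expected=None):
        token = self.peek()
        if expected is not None and token[2] != expected:
            raise SyntaxError("expected %r at token %d" % (expected, self.pos))
        self.pos += 1
        return token

    def parse(self):
        result = self.expr()
        if self.pos != len(self.tokens):
            raise SyntaxError("unexpected input at token %d" % self.pos)
        return result

    def expr(self):  # expr := term (('+' | '-') term)*
        left = self.term()
        while self.peek()[2] in ("+", "-"):
            op = self.take()[2]
            left = "(%s%s%s)" % (left, op, self.term())
        return left

    def term(self):  # term := unary (('*' | '/') unary)*
        left = self.unary()
        while self.peek()[2] in ("*", "/"):
            op = self.take()[2]
            left = "(%s%s%s)" % (left, op, self.unary())
        return left

    def unary(self):  # unary := ('+' | '-') unary | power
        if self.peek()[2] in ("+", "-"):
            op = self.take()[2]
            return "(%s%s)" % (op, self.unary())
        return self.power()

    def power(self):  # power := atom ('^' unary)?, right-associative
        base = self.atom()
        if self.peek()[2] == "^":
            self.take()
            return "(%s^%s)" % (base, self.unary())
        return base

    def atom(self):  # atom := number | '(' expr ')' | name | name '(' args ')'
        number, name, op = self.take()
        if number:
            return number
        if op == "(":
            inner = self.expr()
            self.take(")")
            return inner
        if not name:
            raise SyntaxError("unexpected token %r" % op)
        if self.peek()[2] != "(":
            return self.env.get(name, name)  # a variable, or a bound parameter
        self.take("(")
        args = []
        if self.peek()[2] != ")":
            args.append(self.expr())
            while self.peek()[2] == ",":
                self.take()
                args.append(self.expr())
        self.take(")")
        if name in self.functions:
            params, body = self.functions[name]
            if len(params) != len(args):
                raise SyntaxError("wrong number of arguments for " + name)
            # Parse the body with its parameters bound to the rendered arguments.
            return Parser(body, self.functions, dict(zip(params, args))).parse()
        return "%s(%s)" % (name, ",".join(args))

functions = {"f1": (["x", "y"], "sin(x+y)"), "f2": (["x", "y"], "x+y^2")}
print(Parser("a+b+f1(f2(x,y),x)", functions).parse())
```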

If you need more background on parsing, try reading some basic tutorials on Lex and Yacc (or Bison).