Python:如何使用ast重载运算符

时间:2018-08-19 11:56:24

标签: python overloading abstract-syntax-tree

我的问题是用ast重新定义+运算符来评估表达式。 我有一个表达式列表,可以使用eval()轻松解决:

>>> expr = '1+2*3**4/5'
>>> print(eval(expr))
33.4

但是我想为列表和字典重新定义+运算符(加法符),如下所示:

expr = '[1,2,3]+[4,5,6]'

常规结果为

  

[1、2、3、4、5、6]

但我想拥有

  

[5,7,9]

就像是R语言一样。

类似的字典也应如此:

expr = "{'a':1, 'b':2} + {'a':3, 'b':4}"

我想拥有

  

{'a':4,'b':6}

简而言之,我认为是要替换普通的add函数,即当操作数是列表或命令正确的操作时。

我尝试使用astNodeTransformer,但没有成功。有人可以帮助我吗?

3 个答案:

答案 0 :(得分:5)

制作自己的列表类并在其上定义加法运算符:

class MyKindOfList(list):
    def __add__(self, other):
        return MyKindOfList(a + b for a, b in zip(self, other))

然后您可以执行以下操作:

x = MyKindOfList([1, 2, 3])
y = MyKindOfList([4, 5, 6])

print (x + y)  # prints [5, 7, 9]

答案 1 :(得分:1)

从Aran-Fey的建议开始,然后从this link读到一些东西,我写了一个更具可读性的代码来解决问题

import ast
from itertools import zip_longest

def __custom_add(lhs, rhs):
    if isinstance(lhs,list) and isinstance(rhs, list):
        return [__custom_add(l, r) for l, r in zip_longest(lhs, rhs, fillvalue=0)]

    if isinstance(lhs, dict) and isinstance(rhs, dict):
        keys = lhs.keys() | rhs.keys()
        return {key: __custom_add(lhs.get(key,0), rhs.get(key,0)) for key in keys}

    return lhs + rhs

class SumTransformer(ast.NodeTransformer):

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            new_node = ast.Call(func=ast.Name(id='__custom_add', ctx=ast.Load()),
                            args=[node.left, node.right],
                            keywords = [],
                            starargs = None, kwargs= None
                            )
            ast.copy_location(new_node, node)
            ast.fix_missing_locations(new_node)
            return new_node

        return node

expr = [
    '(2 + 3 * 4)/2',
    '[1, 2] + [3, 4]',
    "{'a': 1} + {'a': -2}"
    ]


for e in expr:
    syntax_tree = ast.parse(e, mode='eval')
    syntax_tree = SumTransformer().visit(syntax_tree)
    res = eval(compile(syntax_tree, '<ast>', 'eval'))
    print(res)

# results

# 7.0
# [4, 6]
# {'a': -1}

感谢所有人的帮助

答案 2 :(得分:0)

即使使用__add__模块,也不能重载内置类的list方法(例如dictast)。但是,您可以将所有x + y之类的附加内容重写为your_custom_addition_function(x, y)之类的函数调用。

从本质上讲,这是一个三步过程:

  1. 使用ast.parse解析输入表达式。
  2. 使用NodeTransformer重写所有对函数调用的添加。
  3. 解析您的自定义加法函数的源代码,并将其添加到步骤1中获得的抽象语法树中。

代码

import ast


def overload_add(syntax_tree):
    # rewrite all additions to calls to our addition function
    class SumTransformer(ast.NodeTransformer):
        def visit_BinOp(self, node):
            lhs = self.visit(node.left)
            rhs = self.visit(node.right)

            if not isinstance(node.op, ast.Add):
                node.left = lhs
                node.right = rhs
                return node

            name = ast.Name('__custom_add', ast.Load())
            args = [lhs, rhs]
            kwargs = []
            return ast.Call(name, args, kwargs)

    syntax_tree = SumTransformer().visit(syntax_tree)
    syntax_tree = ast.fix_missing_locations(syntax_tree)

    # inject the custom addition function into the sytnax tree
    code = '''
def __custom_add(lhs, rhs):
    if isinstance(lhs, list) and isinstance(rhs, list):
        return [__custom_add(l, r) for l, r in zip(lhs, rhs)]

    if isinstance(lhs, dict) and isinstance(rhs, dict):
        keys = lhs.keys() | rhs.keys()
        return {key: __custom_add(lhs.get(key, 0), rhs.get(key, 0)) for key in keys}

    return lhs + rhs
    '''
    add_func = ast.parse(code).body[0]
    syntax_tree.body.insert(0, add_func)

    return syntax_tree

code = '''
print(1 + 2)
print([1, 2] + [3, 4])
print({'a': 1} + {'a': -2})
'''
syntax_tree = ast.parse(code)
syntax_tree = overload_add(syntax_tree)
codeobj = compile(syntax_tree, 'foo.py', 'exec')
exec(codeobj)

# output:
# 3
# [4, 6]
# {'a': -1}

注意事项

  • 该加法函数将以名称__custom_add添加到全局作用域-它可以像其他任何全局函数一样访问,并且可能会被覆盖,阴影,删除或以其他方式篡改。