我的问题是用ast重新定义+
运算符来评估表达式。
我有一个表达式列表,可以使用eval()轻松解决:
>>> expr = '1+2*3**4/5'
>>> print(eval(expr))
33.4
但是我想为列表和字典重新定义+
运算符(加法符),如下所示:
expr = '[1,2,3]+[4,5,6]'
常规结果为
[1、2、3、4、5、6]
但我想拥有
[5,7,9]
就像是R语言一样。
类似的字典也应如此:
expr = "{'a':1, 'b':2} + {'a':3, 'b':4}"
我想拥有
{'a':4,'b':6}
简而言之,我认为是要替换普通的add函数,即当操作数是列表或命令正确的操作时。
我尝试使用ast
和NodeTransformer
,但没有成功。有人可以帮助我吗?
答案 0 :(得分:5)
制作自己的列表类并在其上定义加法运算符:
class MyKindOfList(list):
def __add__(self, other):
return MyKindOfList(a + b for a, b in zip(self, other))
然后您可以执行以下操作:
x = MyKindOfList([1, 2, 3])
y = MyKindOfList([4, 5, 6])
print (x + y) # prints [5, 7, 9]
答案 1 :(得分:1)
从Aran-Fey的建议开始,然后从this link读到一些东西,我写了一个更具可读性的代码来解决问题
import ast
from itertools import zip_longest
def __custom_add(lhs, rhs):
if isinstance(lhs,list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip_longest(lhs, rhs, fillvalue=0)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key,0), rhs.get(key,0)) for key in keys}
return lhs + rhs
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add):
new_node = ast.Call(func=ast.Name(id='__custom_add', ctx=ast.Load()),
args=[node.left, node.right],
keywords = [],
starargs = None, kwargs= None
)
ast.copy_location(new_node, node)
ast.fix_missing_locations(new_node)
return new_node
return node
expr = [
'(2 + 3 * 4)/2',
'[1, 2] + [3, 4]',
"{'a': 1} + {'a': -2}"
]
for e in expr:
syntax_tree = ast.parse(e, mode='eval')
syntax_tree = SumTransformer().visit(syntax_tree)
res = eval(compile(syntax_tree, '<ast>', 'eval'))
print(res)
# results
# 7.0
# [4, 6]
# {'a': -1}
感谢所有人的帮助
答案 2 :(得分:0)
即使使用__add__
模块,也不能重载内置类的list
方法(例如dict
和ast
)。但是,您可以将所有x + y
之类的附加内容重写为your_custom_addition_function(x, y)
之类的函数调用。
从本质上讲,这是一个三步过程:
ast.parse
解析输入表达式。NodeTransformer
重写所有对函数调用的添加。import ast
def overload_add(syntax_tree):
# rewrite all additions to calls to our addition function
class SumTransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if not isinstance(node.op, ast.Add):
node.left = lhs
node.right = rhs
return node
name = ast.Name('__custom_add', ast.Load())
args = [lhs, rhs]
kwargs = []
return ast.Call(name, args, kwargs)
syntax_tree = SumTransformer().visit(syntax_tree)
syntax_tree = ast.fix_missing_locations(syntax_tree)
# inject the custom addition function into the sytnax tree
code = '''
def __custom_add(lhs, rhs):
if isinstance(lhs, list) and isinstance(rhs, list):
return [__custom_add(l, r) for l, r in zip(lhs, rhs)]
if isinstance(lhs, dict) and isinstance(rhs, dict):
keys = lhs.keys() | rhs.keys()
return {key: __custom_add(lhs.get(key, 0), rhs.get(key, 0)) for key in keys}
return lhs + rhs
'''
add_func = ast.parse(code).body[0]
syntax_tree.body.insert(0, add_func)
return syntax_tree
code = '''
print(1 + 2)
print([1, 2] + [3, 4])
print({'a': 1} + {'a': -2})
'''
syntax_tree = ast.parse(code)
syntax_tree = overload_add(syntax_tree)
codeobj = compile(syntax_tree, 'foo.py', 'exec')
exec(codeobj)
# output:
# 3
# [4, 6]
# {'a': -1}
__custom_add
添加到全局作用域-它可以像其他任何全局函数一样访问,并且可能会被覆盖,阴影,删除或以其他方式篡改。