我目前正在玩抽象语法树,使用ast和astor模块。该文档教我如何检索和漂亮打印各种功能的源代码,网上的各种示例显示如何通过将一行的内容替换为另一行或更改所有出现的+到*来修改部分代码。 / p>
但是,我想在各个地方插入其他代码,特别是当函数调用另一个函数时。例如,以下假设函数:
def some_function(param):
if param == 0:
return case_0(param)
elif param < 0:
return negative_case(param)
return all_other_cases(param)
会产生(一旦我们使用astor.to_source(modified_ast)
):
def some_function(param):
if param == 0:
print ("Hey, we're calling case_0")
return case_0(param)
elif param < 0:
print ("Hey, we're calling negative_case")
return negative_case(param)
print ("Seems we're in the general case, calling all_other_cases")
return all_other_cases(param)
这是否可以使用抽象语法树? (注意:我知道在运行代码时,调用的装饰函数会产生相同的结果,但这不是我所追求的;我需要实际输出修改后的代码,并插入比print语句更复杂的内容)。
答案 0 :(得分:2)
从您的问题中不清楚您是否询问如何将节点插入到低级别的AST树中,或者更具体地说是如何使用更高级别的工具进行节点插入以遍历AST树(例如, ast.NodeVisitor
或astor.TreeWalk
)的子类。
以低级别插入节点非常容易。您只需在树中的适当列表中使用list.insert
即可。例如,这里有一些代码可以添加你想要的三个print
调用中的最后一个(另外两个几乎一样容易,他们只需要更多的索引)。大多数代码都是为打印调用构建新的AST节点。实际插入非常短:
source = """
def some_function(param):
if param == 0:
return case_0(param)
elif param < 0:
return negative_case(param)
return all_other_cases(param)
"""
tree = ast.parse(source) # parse an ast tree from the source code
# build a new tree of AST nodes to insert into the main tree
message = ast.Str("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)
tree.body[0].body.insert(1, print_statement) # doing the actual insert here!
# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))
输出将是:
def some_function(param):
if param == 0:
return case_0(param)
elif param < 0:
return negative_case(param)
print("Seems we're in the general case, calling all_other_cases")
return all_other_cases(param)
(注意ast.Call
的参数在Python 3.4和3.5+之间发生了变化。如果您使用的是旧版本的Python,则可能需要添加两个额外的None
参数:{{1 }})
如果你正在使用更高级别的方法,事情会有点棘手,因为代码需要找出插入新节点的位置,而不是使用你自己的输入知识来硬编码。
这是ast.Call(print_func, [message], [], None, None)
子类的快速而又脏的实现,它在任何具有TreeWalk
节点的语句之前添加打印调用作为语句。请注意,Call
个节点包括对类的调用(创建实例),而不仅仅是函数调用。此代码仅处理嵌套调用的最外层,因此如果代码为Call
,则插入的foo(bar())
将仅提及print
:
foo
你会这样称呼它:
class PrintBeforeCall(astor.TreeWalk):
def pre_body_name(self):
body = self.cur_node
print_func = ast.Name("print", ast.Load())
for i, child in enumerate(body[:]):
self.__name = None
self.walk(child)
if self.__name is not None:
message = ast.Str("Calling {}".format(self.__name))
print_statement = ast.Expr(ast.Call(print_func, [message], []))
body.insert(i, print_statement)
self.__name = None
return True
def pre_Call(self):
self.__name = self.cur_node.func.id
return True
这次的输出是:
source = """
def some_function(param):
if param == 0:
return case_0(param)
elif param < 0:
return negative_case(param)
return all_other_cases(param)
"""
tree = ast.parse(source)
walker = PrintBeforeCall() # create an instance of the TreeWalk subclass
walker.walk(tree) # modify the tree in place
print(astor.to_source(tree)
这不是你想要的确切信息,但它很接近。 walker无法详细描述正在处理的案例,因为它只查看被调用的名称函数,而不是查看它的条件。如果你有一套非常明确的东西需要寻找,你或许可以改变它来查看def some_function(param):
if param == 0:
print('Calling case_0')
return case_0(param)
elif param < 0:
print('Calling negative_case')
return negative_case(param)
print('Calling all_other_cases')
return all_other_cases(param)
节点,但我怀疑这会更具挑战性。