ast型节点评估

时间:2018-10-02 14:42:46

标签: python python-3.x abstract-syntax-tree

前言

有一个typed_ast library用于跨Python AST解析和处理(例如,在mypy project 1 中)。

问题

我想知道是否有一种方法可以像标准ast module一样编译节点?

因为这样有效

import ast

code = compile(ast.parse('print("Hello World!")'), '<ast>', 'exec')
eval(code)  # Hello World!

但是这个

from typed_ast import ast3

code = compile(ast3.parse('print("Hello World!")'), '<ast>', 'exec')  # raises exception
eval(code)

给我

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: compile() arg 1 must be a string, bytes or AST object

分析

我知道有a helper class可以在typed_ast.ast27typed_ast.ast3之间进行转换,但是找不到类似的typed_ast.ast3-> ast转换。

我也知道typed-astunparse package,但是它会以字符串形式创建源代码,这不是一个选择,因为我正在使用一些使AST compile可用但不会无法解析的黑客程序-可解析的。

最后有ast3.dump function个文档说的是

  

...如果需要评估,则* annotate_fields *必须设置为False ...

因此似乎有一种方法可以评估生成的转储字符串?或者也许有一种方法可以从ast加载此字符串?

还是我应该编写自己的ast3.NodeTransformer类来执行这种转换?


1 proof

1 个答案:

答案 0 :(得分:0)

到目前为止,我的解决方案使用自定义ast3.NodeTransformer(已在 Python3.5 上进行了测试)

import ast
from functools import partial
from itertools import chain

from typed_ast import ast3


def to_visitor(cls):
    def none(_):
        return None

    try:
        plain_cls = getattr(ast, cls.__name__)
    except AttributeError:
        # node type is not found in `ast` module, skipping
        return none

    def visit(self, node):
        node = self.generic_visit(node)
        result = plain_cls(*map(partial(getattr, node), plain_cls._fields))
        return ast3.copy_location(result, node)

    return visit


def to_subclasses(cls,
                  *,
                  deep=True):
    result = cls.__subclasses__()
    yield from result
    if not deep:
        return
    subclasses_factory = partial(to_subclasses,
                                 deep=deep)
    yield from chain.from_iterable(map(subclasses_factory, result))


class TypedToPlain(ast3.NodeTransformer):
    visitors = {'visit_' + cls.__name__: to_visitor(cls)
                for cls in set(to_subclasses(ast3.AST))}

    def __getattr__(self, name):
        return partial(self.visitors[name], self)

    def generic_visit(self, node):
        for field, old_value in ast3.iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, ast3.AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, ast.AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, ast3.AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node

测试

from typed_ast import ast3

code = compile(TypedToPlain().visit(ast3.parse('print("Hello World!")')),
               '<ast>', 'exec')
eval(code)  # Hello World!