将变量添加到闭包的装饰器

时间:2014-11-10 09:24:23

标签: python closures decorator

我想编写一个将自定义局部变量注入函数的装饰器。

接口大致如下:

def enclose(name, value):
    """Return a decorator meant to expose ``value`` to a function as ``name``.

    This is only the interface sketch from the question: the ``...``
    placeholder marks where the injection logic would go, so ``name`` and
    ``value`` are not used yet and the decorated function is simply called
    through.

    Args:
        name: Name under which the variable should become visible.
        value: Initial value of the injected variable.

    Returns:
        A decorator that wraps a function while keeping its metadata
        (``__name__``, ``__doc__``, ...).
    """
    import functools  # local import keeps the snippet self-contained

    ...
    def decorator(func):
        # Without functools.wraps every decorated function would be
        # reported as "wrapper" and would lose its docstring.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper
    return decorator

期望:

# Desired usage: the decorator should make ``param1`` (initially 1) available
# inside ``f`` even though ``f`` never defines it.
@enclose('param1', 1)
def f():
   # The augmented assignment makes ``param1`` local to ``f``; without real
   # injection this line raises UnboundLocalError.
   param1 += 1
   print param1
调用 f() 时应当能够正常编译并运行,不会报错。

输出:

2

在 Python 中可以实现这一点吗?为什么?

2 个答案:

答案 0 :(得分:2)

我想试着实现一下,看看到底有多难。结果发现相当难。

首先,该如何实现?这个额外的参数究竟应该是注入的局部变量、函数的附加参数,还是非局部(nonlocal)变量?如果是注入的局部变量,每次调用都会得到一个全新的对象,但这样就难以创建更复杂的对象……如果是附加参数,对对象本身的修改会被保留,但对该名称的重新赋值在两次函数调用之间会丢失;此外,还需要解析源码以确定参数应放置的位置,或者直接操作代码对象。最后,如果把变量声明为非局部变量,那么对对象的修改和对名称的赋值都会被保留。实际上,这样的非局部变量相当于一个全局变量,只不过只有被装饰的函数才能访问它。同样,使用非局部变量也需要解析源码并找到放置 nonlocal 声明的位置,或者直接操作代码对象。

最后我决定使用非局部变量,并解析函数的源码。起初我打算直接操作代码对象,但那看起来太复杂了。

以下是装饰器的代码:

import re
import types
import inspect


class DummyInject:
    """Inert stand-in that can be used wherever a decorator factory is expected.

    Calling an instance with keyword arguments yields a decorator that hands
    the decorated function back untouched, and every attribute lookup that
    is not otherwise resolved yields the instance itself, so dotted access
    such as ``stub.anything.at.all(x=1)`` still ends in that no-op decorator.
    """

    def __getattr__(self, name):
        # Only reached for attributes that do not exist; the stand-in
        # answers with itself so attribute chains keep working.
        return self

    def __call__(self, **kwargs):
        # The keyword arguments are accepted and ignored; the result is an
        # identity decorator.
        return lambda func: func


class Inject:
    """Decorator that injects variables into a function through its closure.

    ``Inject(name=value, ...)`` re-compiles the decorated function from its
    source with a ``nonlocal`` declaration for every injected name, wrapped
    in a factory function whose parameters are those names.  The function
    can therefore read *and* rebind the injected names, and the values live
    in closure cells created once, at decoration time.

    Known limitations:

    * the source must be retrievable with ``inspect.getsourcelines``;
    * the function header must end with ``):`` at the end of a line;
    * the first statement of the body must be a simple statement (for
      example a docstring): the ``nonlocal`` declaration is prepended to
      that line with ``;`` so that line numbers stay unchanged, and a
      compound statement (``if``, ``for``, ...) cannot follow ``;``;
    * while the definition is re-executed only builtins and the base names
      of the function's decorators are resolvable, so default values that
      reference other globals fail at decoration time.
    """

    # End of a function header: ")" and ":" as the last tokens of a line.
    function_end = re.compile(r"\)\s*:\s*\n")
    # Leading horizontal whitespace only.  Newlines are excluded on purpose
    # so that a blank line is never mistaken for an indentation prefix.
    indent = re.compile(r"[ \t]+")
    # "@name" optionally followed by a dotted path; group 1 is the base name.
    decorator = re.compile(r"@([a-zA-Z0-9_]+)[.a-zA-Z0-9_]*")
    exec_source = """
def create_new_func({closure_names}):
{func_source}
{indent}return {func_name}"""
    nonlocal_declaration = "{indent}nonlocal {closure_names};"

    def __init__(self, **closure_vars):
        """Remember the ``name=value`` pairs to inject."""
        self.closure_vars = closure_vars

    def __call__(self, func):
        """Return a copy of ``func`` that closes over the injected variables.

        With nothing to inject ``func`` is returned unchanged, because a
        ``nonlocal`` statement without names is a syntax error.
        """
        if not self.closure_vars:
            return func

        lines, _ = inspect.getsourcelines(func)
        self.inject_nonlocal_declaration(lines)
        return self.create_new_function(lines, func)

    def inject_nonlocal_declaration(self, lines):
        """Hide the nonlocal declaration in the first line of the body.

        ``lines`` is modified in place and also returned.  The declaration
        is put on the *same* line as the first statement so that the line
        numbers of the rebuilt function match the original source.
        """
        function_body_start = self.get_function_body_start(lines)
        nonlocals = self.nonlocal_declaration.format(
            indent=self.indent.match(lines[function_body_start]).group(),
            closure_names=", ".join(self.closure_vars)
        )
        lines[function_body_start] = nonlocals + lines[function_body_start]
        return lines

    def get_function_body_start(self, lines):
        """Return the index of the first statement line of the body.

        The header is the first line matching ``function_end``; blank lines
        and comment-only lines after it are skipped.
        """
        line_iter = enumerate(lines)

        found_function_header = False
        for i, line in line_iter:
            if self.function_end.search(line):
                found_function_header = True
                break
        assert found_function_header, "end of function header not found"

        # Same iterator on purpose: scanning continues after the header.
        for i, line in line_iter:
            stripped = line.strip()
            if stripped and not stripped.startswith("#"):
                break

        return i

    def create_new_function(self, lines, func):
        """Compile the patched source and rebuild ``func`` from the result."""
        # prepares source -- eg. making sure indenting is correct
        declaration_indent, body_indent = self.get_indent(lines)
        if not declaration_indent:
            # Top-level function: push it one level in so that it nests
            # inside create_new_func.
            lines = [body_indent + line for line in lines]
        exec_code = self.exec_source.format(
            closure_names=", ".join(self.closure_vars),
            func_source="".join(lines),
            indent=declaration_indent if declaration_indent else body_indent,
            func_name=func.__name__
        )

        # create new func -- mainly only want code object contained by new func
        lvars = {"closure_vars": self.closure_vars}
        gvars = self.get_decorators(exec_code, func.__globals__)
        exec(exec_code, gvars, lvars)
        new_func = eval("create_new_func(**closure_vars)", gvars, lvars)

        # add back bits that enable function to work well:
        # original globals, name, defaults and source reference
        return self.readd_old_references(new_func, func)

    def readd_old_references(self, new_func, old_func):
        """Adds back globals, function name and source reference.

        The returned function runs ``new_func``'s code with ``new_func``'s
        closure, but uses ``old_func``'s globals, name, positional and
        keyword-only defaults, docstring, qualified name and attributes.
        """
        func = types.FunctionType(
            code=self.add_src_ref(new_func.__code__, old_func.__code__),
            globals=old_func.__globals__,
            name=old_func.__name__,
            argdefs=old_func.__defaults__,
            closure=new_func.__closure__
        )
        func.__doc__ = old_func.__doc__
        # argdefs only covers positional defaults; keyword-only defaults
        # have to be carried over separately.
        func.__kwdefaults__ = old_func.__kwdefaults__
        func.__qualname__ = old_func.__qualname__
        func.__dict__.update(old_func.__dict__)
        return func

    def add_src_ref(self, new_code, old_code):
        """Return ``new_code`` with ``old_code``'s filename and first line.

        This makes tracebacks point at the original source file and line
        numbers instead of the generated source.
        """
        if hasattr(new_code, "replace"):
            # Python 3.8+: the positional CodeType signature changed there,
            # so the constructor call below raises TypeError.
            return new_code.replace(
                co_filename=old_code.co_filename,
                co_firstlineno=old_code.co_firstlineno,
            )
        return types.CodeType(
            new_code.co_argcount,
            new_code.co_kwonlyargcount,
            new_code.co_nlocals,
            new_code.co_stacksize,
            new_code.co_flags,
            new_code.co_code,
            new_code.co_consts,
            new_code.co_names,
            new_code.co_varnames,
            old_code.co_filename, # reuse filename
            new_code.co_name,
            old_code.co_firstlineno, # reuse line number
            new_code.co_lnotab,
            new_code.co_freevars,
            new_code.co_cellvars
        )

    def get_decorators(self, source, global_vars):
        """Creates a namespace for exec function creation in. Must remove
        any reference to Inject decorator to prevent infinite recursion.

        Only lines that start with ``@`` are treated as decorators, so an
        ``@`` inside an expression or a string is ignored.  The decorator
        path is evaluated in ``global_vars`` (the function's own module
        globals, not external input).
        """
        namespace = {}
        for line in source.splitlines():
            match = self.decorator.match(line.strip())
            if match is None:
                continue
            try:
                decorator = eval(match.group()[1:], global_vars)
            except (NameError, AttributeError):
                # Not a resolvable decorator, e.g. a docstring line that
                # happens to start with "@".
                continue
            basename = match.group(1)
            if decorator is Inject:
                namespace[basename] = DummyInject()
            elif basename in global_vars:
                # Builtin decorators (staticmethod, property, ...) are not
                # in the module globals; exec resolves them via builtins.
                namespace[basename] = global_vars[basename]
        return namespace

    def get_indent(self, lines):
        """Takes a set of lines used to create a function and returns the
        outer indentation that the function is declared in and the inner
        indentation of the body of the function."""
        body_indent = None
        function_body_start = self.get_function_body_start(lines)
        for line in lines[function_body_start:]:
            match = self.indent.match(line)
            if match:
                body_indent = match.group()
                break
        assert body_indent, "function body is not indented"

        match = self.indent.match(lines[0])
        declaration_indent = match.group() if match else ""

        return declaration_indent, body_indent


if __name__ == "__main__":  # demo / smoke test, only run as a script

    # Ordinary module-level global; the rebuilt function must still see it.
    a = 1

    @Inject(b=10)
    def f(c, d=1000):
        "f uses injected variables"
        return a + b + c + d  # global a + injected b + argument c + default d

    @Inject(var=None)
    def g():
        """Purposefully generate exception to show stacktraces are still
        meaningful."""
        create_name_error # line number 164 (numbering of the author's file)

    print(f(100)) # prints 1111 (1 + 10 + 100 + 1000)
    assert f(100) == 1111
    assert f.__doc__ == "f uses injected variables" # show doc is retained

    try:
        g()
    except NameError:
        raise  # re-raised on purpose so that the traceback is printed
    else:
        assert False  # only reached if g() unexpectedly did not raise
    # stack trace shows NameError on line 164
    # stack trace shows NameError on line 164

运行后输出如下内容:

1111
Traceback (most recent call last):
  File "inject.py", line 171, in <module>
    g()
  File "inject.py", line 164, in g
    create_name_error # line number 164
NameError: name 'create_name_error' is not defined

整个实现很丑陋,但确实可行。需要注意的是,如果把 Inject 用在方法上,那么注入的值会在该类的所有实例之间共享。

答案 1 :(得分:0)

你可以使用全局变量来做,但我不推荐这种方法。

def enclose(name, value):
    """Publish ``value`` as the global ``name`` and return a decorator.

    The assignment goes into the globals of the module that defines
    ``enclose`` and happens as soon as ``enclose(...)`` is called, before
    any function is decorated.  A decorated function therefore only sees
    the variable if it lives in that same module and declares the name
    ``global``.

    Args:
        name: Name of the global variable to create or overwrite.
        value: Value to bind to that name.

    Returns:
        A decorator whose wrapper calls the function unchanged and keeps
        its metadata (``__name__``, ``__doc__``, ...).
    """
    import functools  # local import keeps the snippet self-contained

    globals()[name] = value

    def decorator(func):
        # Without functools.wraps every decorated function would be
        # reported as "wrapper" and would lose its docstring.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    return decorator

# enclose() has already bound the module-level global ``param1`` to 1 by the
# time ``f`` is defined.
@enclose('param1', 1)
def f():
    # ``global`` is required: without it the augmented assignment below
    # would make ``param1`` a local name.
    global param1
    param1 += 1

    print(param1)

# Each call increments the shared global again.
f()