使用python decorator自动替换函数参数默认值?

时间:2012-08-02 08:33:34

标签: python function arguments decorator

实际上标题并不能完全反映我想问的问题。我的目的是这样的:我正在使用matplotlib编写一些绘图函数。我有一系列用于不同绘图目的的功能。比如line的line_plot(),bar等的bar_plot(),例如:

import matplotlib.pyplot as plt
def line_plot(axes=None,x=None,y=None):
    if axes==None:
        fig=plt.figure()
        axes=fig.add_subplot(111)
    else:
        pass
    axes.plot(x,y)

def bar_plot(axes=None,x=None,y=None):
    if axes==None:
        fig=plt.figure()
        axes=fig.add_subplot(111)
    else:
        pass
    axes.bar(left=x,height=y)

然而问题是,对于定义的每个函数,我必须重复这部分代码:

    if axes==None:
        fig=plt.figure()
        axes=fig.add_subplot(111)
    else:
        pass

有没有像使用装饰器那样的方法,我可以在绘图函数的定义之前应用,它将自动执行代码的重复部分?因此,我不必每次都重复它们。

一个可能的选择是定义这样的函数:

def check_axes(axes):
    if axes==None:
        fig=plt.figure()
        axes=fig.add_subplot(111)
        return axes
    else:
        return axes

然后示例如下:

import matplotlib.pyplot as plt    
def line_plot(axes=None,x=None,y=None):
    axes=check_axes(axes)
    axes.plot(x,y)

def bar_plot(axes=None,x=None,y=None):
    axes=check_axes(axes)
    axes.bar(left=x,height=y)

但是有更好/更干净/更pythonic的方式吗?我想我可以使用装饰器,但没有想出来。任何人都可以提出一些想法吗?

谢谢!

1 个答案:

答案 0 :(得分:7)

以下是如何使用装饰器进行操作:

import matplotlib.pyplot as plt    

def check_axes(plot_fn):
    def _check_axes_wrapped_plot_fn(axes=None, x=None, y=None):
        if not axes:
            fig = plt.figure()
            axes = fig.add_subplot(111)
            return plot_fn(axes, x, y)
        else:
            return plot_fn(axes, x, y)
    return _check_axes_wrapped_plot_fn

@check_axes
def line_plot(axes, x=None, y=None):
    axes.plot(x, y)

@check_axes
def bar_plot(axes, x=None, y=None):
    axes.bar(left=x, height=y)

工作原理:@check_axes语法重新定义了修饰函数的名称,例如: line_plot是由装饰器创建的新函数,即_check_axes_wrapped_plot_fn。这个“包装”函数处理axes - 检查逻辑,然后调用原始绘图函数。

如果您希望check_axes能够装饰任何以axes作为其第一个参数的绘图函数,而不仅仅是那些仅采用xy的绘图函数参数,您可以使用Python方便的*语法来获取任意参数列表:

def check_axes(plot_fn):
    def _check_axes_wrapped_plot_fn(axes=None, *args):
        if not axes:
            fig = plt.figure()
            axes = fig.add_subplot(111)
            return plot_fn(axes, *args)  # pass all args after axes
        else:
            return plot_fn(axes, *args)  # pass all args after axes
    return _check_axes_wrapped_plot_fn  

现在,无论是“更好/更清洁/更多Pythonic”,这可能是一个争论的问题,并取决于更大的背景。

顺便说一下,本着“更多Pythonic”的精神,我重新格式化了你的代码,使其更接近PEP8风格指南。注意参数列表中逗号后面的空格,=赋值运算符周围的空格(但=用于函数关键字参数时没有),并且说not axes而不是axes == None