用装饰器理解这段Python代码

时间:2017-03-03 09:50:35

标签: python python-decorators

所以我试图从here解密一些代码。下面我复制并粘贴了我不太了解的相关代码。

def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated

class Network(object):

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network
        self.inputs = inputs
        print(self.inputs)
        # The current list of terminal nodes
        self.terminals = []
        # Mapping from layer names to layers
        self.layers = dict(inputs)
        print(self.layers)
        # If true, the resulting variables are set as trainable
        self.trainable = trainable

    …

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, string_types):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

 ....

 # equivalent to max_pool = layer(max_pool)
    @layer
    def max_pool(self, inp, k_h, k_w, s_h, s_w, name, padding='SAME'):
        self.validate_padding(padding)
        return tf.nn.max_pool(inp,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

我理解上面的代码,虽然我在尝试理解下面的代码时遇到了一些麻烦:

class PNet(Network):
    def setup(self):
        (self.feed('data') 
             .conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1')
             .prelu(name='PReLU1')
             .max_pool(2, 2, 2, 2, name='pool1')
             .conv(3, 3, 16, 1, 1, padding='VALID', relu=False, name='conv2')
             .prelu(name='PReLU2')
             .conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv3')
             .prelu(name='PReLU3')
             .conv(1, 1, 2, 1, 1, relu=False, name='conv4-1')
             .softmax(3,name='prob1'))

        (self.feed('PReLU3') #pylint: disable=no-value-for-parameter
             .conv(1, 1, 4, 1, 1, relu=False, name='conv4-2'))

特别是,我对这部分代码工作感到困惑:

self.feed('data') 
             .conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1')
             .prelu(name='PReLU1')
             .max_pool(2, 2, 2, 2, name='pool1')
             .conv(3, 3, 16, 1, 1, padding='VALID', relu=False, name='conv2')
             .prelu(name='PReLU2')
             .conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv3')
             .prelu(name='PReLU3')
             .conv(1, 1, 2, 1, 1, relu=False, name='conv4-1')
             .softmax(3,name='prob1'))

它也可以写成: self.feed('data').conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1').prelu(name='PReLU1')...

这就是我不明白的地方,feed本身是Network类的一种方法,但我如何能够访问feed的方法等。< / p>

2 个答案:

答案 0 :(得分:1)

这与装饰者无关。

feed方法 - 以及可能是convprelu方法 - 返回self。这意味着您可以继续调用调用该方法的结果的方法。

这被称为&#34;方法链接&#34 ;;它在Ruby等语言中比较常见,但你也可以在Python中使用它。

答案 1 :(得分:0)

真的很简单。

因此,当您启动此链时,您将拥有对象self,您可以在其上调用方法feed()方法,如果您查看源代码,则返回self。但这是修改后的自我。因此,此时feed()被“消耗”,而您将留下类似(modified)self.prelu()....的内容。这用另一种方法重复。它会一直重复,直到没有电话为止。