所以我试图从here解密一些代码。下面我复制并粘贴了我不太了解的相关代码。
def layer(op):
'''Decorator for composable network layers.'''
def layer_decorated(self, *args, **kwargs):
# Automatically set a name if not provided.
name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
# Figure out the layer inputs.
if len(self.terminals) == 0:
raise RuntimeError('No input variables found for layer %s.' % name)
elif len(self.terminals) == 1:
layer_input = self.terminals[0]
else:
layer_input = list(self.terminals)
# Perform the operation and get the output.
layer_output = op(self, layer_input, *args, **kwargs)
# Add to layer LUT.
self.layers[name] = layer_output
# This output is now the input for the next layer.
self.feed(layer_output)
# Return self for chained calls.
return self
return layer_decorated
class Network(object):
def __init__(self, inputs, trainable=True):
# The input nodes for this network
self.inputs = inputs
print(self.inputs)
# The current list of terminal nodes
self.terminals = []
# Mapping from layer names to layers
self.layers = dict(inputs)
print(self.layers)
# If true, the resulting variables are set as trainable
self.trainable = trainable
…
def feed(self, *args):
'''Set the input(s) for the next operation by replacing the terminal nodes.
The arguments can be either layer names or the actual layers.
'''
assert len(args) != 0
self.terminals = []
for fed_layer in args:
if isinstance(fed_layer, string_types):
try:
fed_layer = self.layers[fed_layer]
except KeyError:
raise KeyError('Unknown layer name fed: %s' % fed_layer)
self.terminals.append(fed_layer)
return self
....
# equivalent to max_pool = layer(max_pool)
@layer
def max_pool(self, inp, k_h, k_w, s_h, s_w, name, padding='SAME'):
self.validate_padding(padding)
return tf.nn.max_pool(inp,
ksize=[1, k_h, k_w, 1],
strides=[1, s_h, s_w, 1],
padding=padding,
name=name)
我理解上面的代码,虽然我在尝试理解下面的代码时遇到了一些麻烦:
class PNet(Network):
def setup(self):
(self.feed('data')
.conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1')
.prelu(name='PReLU1')
.max_pool(2, 2, 2, 2, name='pool1')
.conv(3, 3, 16, 1, 1, padding='VALID', relu=False, name='conv2')
.prelu(name='PReLU2')
.conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv3')
.prelu(name='PReLU3')
.conv(1, 1, 2, 1, 1, relu=False, name='conv4-1')
.softmax(3,name='prob1'))
(self.feed('PReLU3') #pylint: disable=no-value-for-parameter
.conv(1, 1, 4, 1, 1, relu=False, name='conv4-2'))
特别是,我对这部分代码工作感到困惑:
self.feed('data')
.conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1')
.prelu(name='PReLU1')
.max_pool(2, 2, 2, 2, name='pool1')
.conv(3, 3, 16, 1, 1, padding='VALID', relu=False, name='conv2')
.prelu(name='PReLU2')
.conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv3')
.prelu(name='PReLU3')
.conv(1, 1, 2, 1, 1, relu=False, name='conv4-1')
.softmax(3,name='prob1'))
它也可以写成:
self.feed('data').conv(3, 3, 10, 1, 1, padding='VALID', relu=False, name='conv1').prelu(name='PReLU1')...
这就是我不明白的地方,feed
本身是Network
类的一种方法,但我如何能够访问feed
的方法等。< / p>
答案 0 :(得分:1)
这与装饰者无关。
feed
方法 - 以及可能是conv
和prelu
方法 - 返回self
。这意味着您可以继续调用调用该方法的结果的方法。
这被称为&#34;方法链接&#34 ;;它在Ruby等语言中比较常见,但你也可以在Python中使用它。
答案 1 :(得分:0)
真的很简单。
因此,当您启动此链时,您将拥有对象self
,您可以在其上调用方法feed()
方法,如果您查看源代码,则返回self
。但这是修改后的自我。因此,此时feed()
被“消耗”,而您将留下类似(modified)self.prelu()....
的内容。这用另一种方法重复。它会一直重复,直到没有电话为止。