TensorFlow中{get}的吸气概念

时间:2017-05-28 22:18:00

标签: tensorflow

在TensorFlow中,getter的概念和用途是什么?

tf.get_variable()的签名是:

get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=True,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None
)

custom_getter的定义在文档中给出如下:

  

custom_getter: Callable将第一个参数作为true   getter,并允许覆盖内部get_variable方法。该   custom_getter的签名应该与此方法的签名匹配,但是   大多数面向未来的版本将允许更改:def   custom_getter(getter,* args,** kwargs)。直接访问所有人   get_variable参数也是允许的:def custom_getter(getter,   name,* args,** kwargs)。简单的身份自定义getter   使用修改后的名称创建变量是:python def   custom_getter(getter,name,* args,** kwargs):return getter(name +   ' _suffix',* args,** kwargs)

不幸的是,它不是很清楚。有人可以扩展它吗?

2 个答案:

答案 0 :(得分:1)

this custom getter为例,它根据可变的形状大小重写<uses-permission android:name="android.permission.INTERNET" />

答案 1 :(得分:1)

这个想法是它为变量创建过程提供了一个“钩子”,这是一种在创建变量时可能覆盖某些东西的方法。对于特定(狭隘)的问题,这可能非常方便。

从概念上讲,“自定义getter”类似于Python decorator:你编写一个函数来获取原始函数及其参数作为参数,除了你必须返回结果而不是修改函数。

您还可以在tf.variable_scope()中将自定义getter指定为参数,允许您一次将其应用于多个变量。

作为一个有点武断的例子,让我们假设你有大量的代码来创建某种类型的网络。你有一天早上醒来想要尝试L2归一化所有变量可能有助于网络的性能。您可以将整个网络封装在范围内并执行以下操作(伪代码),而不是编辑所有已创建的图层及其中的变量:

with tf.variable_scope( "L2", custom_getter =
    lambda getter, name, shape, *args, **kwargs:
        tf.nn.l2_normalize( getter( name = name, shape = shape, *args, **kwargs ) ) ):
    # the original network here

这将自动L2规范化网络中的所有变量,并保持这样。当然,如果您不想对所有这些执行此操作,则可以编写更多代码并按名称或其他参数进行筛选。您也不必将此作为lambda函数执行,您可以使用def编写一个普通函数,并将其作为custom_getter参数传递。

另请注意,这个简单的示例提交了一个罪:返回的变量将是 tensor 而不是变量,因此tf.assign()将在此变量上失败,例如。

作为参考,默认的getter在tensorflow/python/ops/variable_scope.py_VariableStore方法at this exact link for r1.8中的get_variable()中定义。