在TensorFlow中,getter的概念和用途是什么?
tf.get_variable()
的签名是:
get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None
)
custom_getter
的定义在文档中给出如下:
custom_getter: Callable将第一个参数作为true getter,并允许覆盖内部get_variable方法。该 custom_getter的签名应该与此方法的签名匹配,但是 大多数面向未来的版本将允许更改:def custom_getter(getter,* args,** kwargs)。直接访问所有人 get_variable参数也是允许的:def custom_getter(getter, name,* args,** kwargs)。简单的身份自定义getter 使用修改后的名称创建变量是:python def custom_getter(getter,name,* args,** kwargs):return getter(name + ' _suffix',* args,** kwargs)
不幸的是,它不是很清楚。有人可以扩展它吗?
答案 0 :(得分:1)
以this custom getter为例,它根据可变的形状大小重写<uses-permission android:name="android.permission.INTERNET" />
。
答案 1 :(得分:1)
这个想法是它为变量创建过程提供了一个“钩子”,这是一种在创建变量时可能覆盖某些东西的方法。对于特定(狭隘)的问题,这可能非常方便。
从概念上讲,“自定义getter”类似于Python decorator:你编写一个函数来获取原始函数及其参数作为参数,除了你必须返回结果而不是修改函数。
您还可以在tf.variable_scope()
中将自定义getter指定为参数,允许您一次将其应用于多个变量。
作为一个有点武断的例子,让我们假设你有大量的代码来创建某种类型的网络。你有一天早上醒来想要尝试L2归一化所有变量可能有助于网络的性能。您可以将整个网络封装在范围内并执行以下操作(伪代码),而不是编辑所有已创建的图层及其中的变量:
with tf.variable_scope( "L2", custom_getter =
lambda getter, name, shape, *args, **kwargs:
tf.nn.l2_normalize( getter( name = name, shape = shape, *args, **kwargs ) ) ):
# the original network here
这将自动L2规范化网络中的所有变量,并保持这样。当然,如果您不想对所有这些执行此操作,则可以编写更多代码并按名称或其他参数进行筛选。您也不必将此作为lambda函数执行,您可以使用def
编写一个普通函数,并将其作为custom_getter
参数传递。
另请注意,这个简单的示例提交了一个罪:返回的变量将是 tensor 而不是变量,因此tf.assign()
将在此变量上失败,例如。
作为参考,默认的getter在tensorflow/python/ops/variable_scope.py
类_VariableStore
方法at this exact link for r1.8中的get_variable()
中定义。