如何在急切的执行模式中重用tensorflow变量?

时间:2018-06-14 08:30:19

标签: python tensorflow

在tensorflow中调用get_variable()函数时,"重用"的行为。 flag在tensorflow api doc中定义为AUTO_REUSE:

  

重用:True,None或tf.AUTO_REUSE; ...... 急切执行时   启用后,此参数始终强制为tf.AUTO_REUSE

然而,当我真正按照网页上的建议运行演示代码时:

tf.enable_eager_execution()
def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v
v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

失败了。 (如果第一行被删除,它会通过,如预期的那样。)

那么如何在急切模式下重用变量?这是一个错误还是我错过了什么?

3 个答案:

答案 0 :(得分:4)

在急切的模式下,事情变得更简单......除了那些因图形模型使用太久而受到脑损伤的人(比如我)。

Eager以标准方式工作,其中变量仅在引用时才持续。如果你停止引用它们,它们就会消失。

要进行变量共享,如果你要使用numpy(或其他任何东西)进行计算,你会做同样的事情:将变量存储在一个对象中,然后重用这个对象。

这就是为什么渴望与keras API有如此多的亲和力的原因,因为keras主要处理对象。

所以再看看你的函数就numpy而言(对于像我这样从图中恢复的人来说很有用)。您是否期望对foo的两次调用返回相同的数组对象?当然不是。

答案 1 :(得分:1)

tensorflow/python/ops/variable_scope.py中的文档似乎已更新。

来自line 310

  

“重用:布尔值,无或tf.AUTO_REUSE。控制变量的重用或创建。启用急切执行后,此参数始终被强制为False。”

并且来自line 2107

  

“启用急切执行后,除非EagerVariableStore或模板当前处于活动状态,否则始终会创建新变量。”

答案 2 :(得分:1)

我发现在Eager Execution中重用变量是最简单的方法,只需将引用传递给周围的同一变量即可:

import tensorflow as tf
tf.enable_eager_execution()
import numpy as np

class MyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()

    def build(self, input_shape):
        # bias specific for each layer
        self.B = self.add_variable('B', [1])

    def call(self, input, A):
        # some function involving input, common weights, and layer-specific bias
        return tf.matmul(input, A) + self.B

class MyModel(tf.keras.Model):    
    def __init__(self):
        super(MyModel, self).__init__()

    def build(self, input_shape):
        # common vector of weights
        self.A = self.add_variable('A', [int(input_shape[-1]), 1])

        # layers which will share A
        self.layer1 = MyLayer()
        self.layer2 = MyLayer()

    def call(self, input):
        result1 = self.layer1(input, self.A)
        result2 = self.layer2(input, self.A)
        return result1 + result2

if __name__ == "__main__":
    data = np.random.normal(size=(1000, 3))
    model = MyModel()
    predictions = model(data)
    print('\n\n')
    model.summary()
    print('\n\n')
    print([v.name for v in model.trainable_variables])

输出为:

enter image description here

因此,我们有一个共享的尺寸3的权重参数my_model/A和两个尺寸1的偏置参数my_model/my_layer/Bmy_model/my_layer_1/B,总共5个可训练参数。该代码是独立运行的,因此可以随意使用它。