Tensorflow-tf.variable_scope,GAN

时间:2018-06-22 15:37:07

标签: tensorflow machine-learning neural-network generative-adversarial-network

我正在尝试为项目构建GAN,我真的很想了解tensorflow的variable_scope中的此变量共享是如何工作的。

对于构建GAN,我有一个生成器网络和两个鉴别器网络: 一个鉴别器被提供给真实图像,而一个鉴别器被提供给生成器创建的伪图像。重要的是,提供有真实图像的鉴别器和提供有伪图像的鉴别器需要共享相同的权重。为此,我需要共享权重。

我有一个鉴别器和生成器定义,可以这样说:

def discriminator(images, reuse=False):
    with variable_scope("discriminator", reuse=reuse):

        #.... layer definitions, not important here
        #....
        logits = tf.layers.dense(X, 1)
        logits = tf.identity(logits, name="logits")
        out = tf.sigmoid(logits, name="out")
        # 14x14x64
    return logits, out

def generator(input_z, reuse=False):
    with variable_scope("generator", reuse=reuse):

        #.. not so important 
        out = tf.tanh(logits)

    return out

现在生成器和鉴别函数被调用:

g_model = generator(input_z)
d_model_real, d_logits_real = discriminator(input_real)

#Here , reuse=True should produce the weight sharing between d_model_real, d_logits_real
#and d_model_fake and d_logits_fake.. why?
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

为什么 second 调用中的reuse = True语句会产生权重共享?据我了解,您需要在第一个调用中决定重用变量,以便可以在程序的稍后位置使用它们。

如果有人可以向我解释这一点,我将感到非常高兴,但我找不到这个主题的很好来源,这对我来说确实很令人困惑和复杂。 谢谢!

1 个答案:

答案 0 :(得分:3)

在后台使用tf.get_variable()创建变量。

此函数将在变量名之前添加作用域,并在创建新变量之前检查其是否存在。

例如,如果您在范围"fc"中并调用tf.get_variable("w", [10,10]),则变量名称将为"fc/w:0"

现在,当您第二次执行此操作时,如果reuse=True,范围将再次为"fc",并且get_variable将重用变量"fc/w:0"

但是,如果reuse=False会出现错误,因为变量"fc/w:0"已经存在,提示您使用其他名称或使用reuse=True

示例:

In [1]: import tensorflow as tf

In [2]: with tf.variable_scope("fc"):
   ...:      v = tf.get_variable("w", [10,10])
   ...:

In [3]: v
Out[3]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>

In [4]: with tf.variable_scope("fc"):
   ...:      v = tf.get_variable("w", [10,10])
   ...:
ValueError: Variable fc/w already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?

In [5]: with tf.variable_scope("fc", reuse=True):
   ...:      v = tf.get_variable("w", [10,10])
   ...:

In [6]: v
Out[6]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>

请注意,代替共享权重,您只能实例化一个鉴别器。然后,您可以决定使用placeholder_with_default将其提供给真实数据还是生成的数据。