GAN diverges when applying variable reuse in TensorFlow

Date: 2018-11-10 22:43:26

Tags: python tensorflow generative-adversarial-network

I'm building a GAN, and when I call the discriminator twice using variable reuse, my GAN starts to diverge. I first created the discriminator like this:

import os

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.layers)


def discriminator(self, x_past, x_future, gen_future):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow C++ logging
    with tf.variable_scope("disc") as disc:
        # Condition both inputs on x_past along the channel axis (axis 2)
        gen_future = tf.concat([gen_future, x_past], 2)
        x_future = tf.concat([x_future, x_past], 2)
        # Stack generated and real samples along the batch axis (generated first)
        x_in = tf.concat([gen_future, x_future], 0)
        conv1 = tf.layers.conv1d(inputs=x_in, filters=20, kernel_size=3, strides=1,
                                 padding='same', activation=tf.nn.relu)
        max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')
        conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=3, kernel_size=2, strides=1,
                                 padding='same', activation=tf.nn.relu)
        max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')

        # Flatten and add dropout
        flat = tf.reshape(max_pool_2, (-1, 9))
        flat = tf.nn.dropout(flat, keep_prob=self.keep_prob)

        # Predictions
        logits = tf.layers.dense(flat, 2)

        # Note: x_in puts gen_future first, so these first batch_size rows
        # are the logits for the generated samples, not the real ones
        y_true = logits[:self.batch_size]
        y_gen = logits[self.batch_size:]

        return y_true, y_gen

I call it like this:

y_true, y_gen = self.discriminator(x_past, x_future, gen_future)
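In this first version, `x_in` is built as `tf.concat([gen_future, x_future], 0)`, so the generated samples occupy the first `self.batch_size` rows of the batch. The plain-Python sketch below is only a toy stand-in for the graph (not TensorFlow code; `batch_size = 3` and the string "samples" are made up for illustration) that traces which rows `logits[:batch_size]` and `logits[batch_size:]` pick up.

```python
# Toy model (plain Python, not TensorFlow) of the combined-batch call.
# Each "sample" is a label string so we can see which rows end up where.
batch_size = 3  # hypothetical batch size

gen_future = [f"gen_{i}" for i in range(batch_size)]  # generated samples
x_future = [f"real_{i}" for i in range(batch_size)]   # real samples

# tf.concat([gen_future, x_future], 0) stacks along the batch axis:
# generated rows first, then real rows.
x_in = gen_future + x_future

# The discriminator scores each row independently; here we just tag it.
logits = [f"D({row})" for row in x_in]

y_true = logits[:batch_size]  # first half of the combined batch
y_gen = logits[batch_size:]   # second half of the combined batch

print(y_true)  # ['D(gen_0)', 'D(gen_1)', 'D(gen_2)']
print(y_gen)   # ['D(real_0)', 'D(real_1)', 'D(real_2)']
```

If the toy mirrors the real graph, `y_true` in the first version holds the scores of the generated samples and `y_gen` those of the real ones. Whether that accounts for any difference in training depends on how the losses consume `y_true` and `y_gen`, which the question does not show.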

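I was able to train the GAN correctly this way. Now I need to call the discriminator with reuse instead of feeding the real and fake data in through a single call. I changed it to: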

import os

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.layers)


def discriminator(self, x_past, x_future, reuse=False):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow C++ logging
    # reuse=True reopens the "disc" scope so its layers look up the existing
    # variables instead of creating new ones
    with tf.variable_scope("disc", reuse=reuse) as disc:
        # Condition the input on x_past along the channel axis (axis 2)
        x_in = tf.concat([x_future, x_past], 2)
        conv1 = tf.layers.conv1d(inputs=x_in, filters=20, kernel_size=3, strides=1,
                                 padding='same', activation=tf.nn.relu)
        max_pool_1 = tf.layers.max_pooling1d(inputs=conv1, pool_size=2, strides=2, padding='same')
        conv2 = tf.layers.conv1d(inputs=max_pool_1, filters=3, kernel_size=2, strides=1,
                                 padding='same', activation=tf.nn.relu)
        max_pool_2 = tf.layers.max_pooling1d(inputs=conv2, pool_size=2, strides=2, padding='same')

        # Flatten and add dropout
        flat = tf.reshape(max_pool_2, (-1, 9))
        flat = tf.nn.dropout(flat, keep_prob=self.keep_prob)

        # Predictions
        logits = tf.layers.dense(flat, 2)
        return logits
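The `reuse` argument is what lets the second call share weights with the first. As a rough mental model, the toy store below (plain Python; `ToyVariableStore` and the variable name `"disc/dense/kernel"` are hypothetical, not TensorFlow API) mimics the documented `tf.get_variable` behaviour inside `tf.variable_scope`: without reuse, creating a variable whose name already exists raises `ValueError`; with `reuse=True`, the existing variable is returned, and a missing one raises `ValueError`.

```python
# Toy stand-in for TensorFlow 1.x variable sharing (hypothetical, not the real API).
class ToyVariableStore:
    def __init__(self):
        self._vars = {}  # full variable name -> toy weights

    def get_variable(self, name, reuse=False):
        if reuse:
            # reuse=True: look up an existing variable, never create one.
            if name not in self._vars:
                raise ValueError(f"Variable {name} does not exist")
            return self._vars[name]
        # reuse=False: always create; an existing name is an error.
        if name in self._vars:
            raise ValueError(f"Variable {name} already exists")
        self._vars[name] = [0.0, 0.0]  # toy weights
        return self._vars[name]


store = ToyVariableStore()
# First discriminator call (real data): creates the weights.
w_real_call = store.get_variable("disc/dense/kernel")
# Second call (generated data, reuse=True): returns the very same weights.
w_gen_call = store.get_variable("disc/dense/kernel", reuse=True)
print(w_real_call is w_gen_call)  # True
```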

And call it like this:

y_true = self.discriminator(x_past, x_future)
y_gen = self.discriminator(x_past, gen_future, reuse=True)
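Because the discriminator now shares its variables across both calls, the refactor should in principle compute the same per-sample logits as the single combined call, apart from the order of the rows. The toy check below (plain Python, no TensorFlow; `toy_discriminator`, the weights and the samples are all made up, and dropout is left out so the results are deterministic) compares the two ways of scoring a batch.

```python
# Toy check (plain Python, no TensorFlow): with shared weights, scoring real and
# generated batches in two calls gives the same per-sample logits as one call on
# the concatenated batch, provided the slices follow the concatenation order.

weights = [0.5, -1.0, 2.0]  # stands in for the shared "disc" variables


def toy_discriminator(batch, w):
    # One logit per sample: a dot product with the shared weights.
    return [sum(wi * xi for wi, xi in zip(w, x)) for x in batch]


real = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
gen = [[2.0, 2.0, 0.0], [1.0, 1.0, 1.0]]
batch_size = len(real)

# Second version: two calls that share weights (reuse=True).
y_true_split = toy_discriminator(real, weights)
y_gen_split = toy_discriminator(gen, weights)

# First version: one call on [gen; real] (generated rows first), then slice.
combined = toy_discriminator(gen + real, weights)
print(combined[batch_size:] == y_true_split)  # True: real scores are the second half
print(combined[:batch_size] == y_gen_split)   # True: generated scores are the first half
```

If the two versions still behave differently once the slices are matched to the concatenation order, the cause would have to lie elsewhere in the graph, such as the loss definitions, which the question does not show.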

Now it starts to diverge. Any idea why?

0 answers