Should we stop training the discriminator while training the generator in the CycleGAN tutorial?

Time: 2019-11-06 09:06:21

Tags: python tensorflow deep-learning generative-adversarial-network gan

In the code that the TensorFlow tutorial provides for CycleGAN, the discriminators and the generators are trained at the same time.


    def train_step(real_x, real_y):
      # persistent is set to True because the tape is used more than
      # once to calculate the gradients.
      with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y
        # Generator F translates Y -> X.

        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        # Total generator loss = adversarial loss + cycle loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

      # Calculate the gradients for generator and discriminator
      generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                            generator_g.trainable_variables)
      generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                            generator_f.trainable_variables)

      discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                                discriminator_x.trainable_variables)
      discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                                discriminator_y.trainable_variables)

      # Apply the gradients to the optimizer
      generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                                generator_g.trainable_variables))

      generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                                generator_f.trainable_variables))

      discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                    discriminator_x.trainable_variables))

      discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                    discriminator_y.trainable_variables))

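However, when training a GAN, we are supposed to stop training the discriminator while the generator is being trained. What is the benefit of doing it this way?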

2 answers:

Answer 0 (score: 0)

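In a GAN you do not stop training D or G; they are trained simultaneously. Here, the code first computes the gradients for every network (so that neither D nor G changes before the current losses have been computed) and then uses those gradients to update the weights. Your question is not clear: the benefit of what?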
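The difference between "compute every gradient first, then update" and "update one network, then compute the other's gradient" can be seen on a toy problem. The sketch below is only an illustration under stated assumptions: the scalar parameters `g` and `d`, the bilinear objective `g * d`, and the function names are all made up for this example and are not part of the tutorial or of TensorFlow.

```python
# Toy illustration only: scalar "networks" and a made-up bilinear objective,
# not the CycleGAN losses from the tutorial.


def grads(g, d):
    """Return the partial derivatives of the toy objective f(g, d) = g * d.

    The "generator" parameter g tries to minimise f, while the
    "discriminator" parameter d tries to maximise it.
    """
    df_dg = d  # derivative of g * d with respect to g
    df_dd = g  # derivative of g * d with respect to d
    return df_dg, df_dd


def simultaneous_step(g, d, lr=0.1):
    """Compute both gradients from the same (g, d), then update both.

    This mirrors the order in the tutorial's train_step: every gradient is
    taken before any weight changes.
    """
    df_dg, df_dd = grads(g, d)
    new_g = g - lr * df_dg  # generator: gradient descent
    new_d = d + lr * df_dd  # discriminator: gradient ascent
    return new_g, new_d


def alternating_step(g, d, lr=0.1):
    """Update d first, then compute g's gradient against the updated d."""
    _, df_dd = grads(g, d)
    new_d = d + lr * df_dd
    df_dg, _ = grads(g, new_d)  # gradient now sees the already-updated d
    new_g = g - lr * df_dg
    return new_g, new_d


if __name__ == "__main__":
    print("simultaneous:", simultaneous_step(1.0, 1.0))
    print("alternating: ", alternating_step(1.0, 1.0))
```

Starting from the same point, the two schemes yield slightly different generator updates, because in the alternating version the generator's gradient is evaluated against a discriminator that has already moved. Note also that in the tutorial's `train_step`, each `tape.gradient` call is restricted to one network's `trainable_variables`, so a generator loss never changes discriminator weights even though all four updates happen in the same step.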

Answer 1 (score: 0)

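To train in an adversarial manner, the gradients of the discriminator and generator networks should be updated separately. The discriminator becomes stronger as the generator produces more realistic samples, and vice versa. If you update these networks together, "adversarial" training does not take place; as far as I know, you are unlikely to obtain satisfactory synthetic samples that way.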