在Tensorlfow教程为CycleGAN提供的代码中,他们同时训练了鉴别器和生成器。
def train_step(real_x, real_y): # persistent is set to True because the tape is used more than # once to calculate the gradients. with tf.GradientTape(persistent=True) as tape: # Generator G translates X -> Y # Generator F translates Y -> X. fake_y = generator_g(real_x, training=True) cycled_x = generator_f(fake_y, training=True) fake_x = generator_f(real_y, training=True) cycled_y = generator_g(fake_x, training=True) # same_x and same_y are used for identity loss. same_x = generator_f(real_x, training=True) same_y = generator_g(real_y, training=True) disc_real_x = discriminator_x(real_x, training=True) disc_real_y = discriminator_y(real_y, training=True) disc_fake_x = discriminator_x(fake_x, training=True) disc_fake_y = discriminator_y(fake_y, training=True) # calculate the loss gen_g_loss = generator_loss(disc_fake_y) gen_f_loss = generator_loss(disc_fake_x) total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y) # Total generator loss = adversarial loss + cycle loss total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y) total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x) disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x) disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y) # Calculate the gradients for generator and discriminator generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables) generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables) discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables) discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables) # Apply the gradients to the optimizer generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables)) generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables)) discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables)) discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))
但是,在训练GAN网络时,我们需要在训练生成器网络时停止训练鉴别器。 使用它有什么好处?
答案 0 :(得分:0)
在GAN中,您不会停止训练D或G。它们是同时训练的。 在这里,他们首先计算每个网络的梯度值(在计算电流损耗之前不改变D或G),然后使用这些值更新权重。 您的问题不清楚,这有什么好处?
答案 1 :(得分:0)
为了以对抗方式进行训练,鉴别器和生成器网络的梯度应该分别更新。鉴别器变得更强,因为生成器产生更真实的样本,反之亦然。如果你一起更新这些网络,“对抗性”训练就不会发生——据我所知,你不太可能通过这种方式获得令人满意的合成样本。