Setting .trainable variables in a GAN implementation in tf.keras

Time: 2019-11-11 15:04:09

Tags: python tensorflow tensorflow2.0 tf.keras

I am confused by the .trainable statement of tf.keras.Model in the implementation of a GAN.

Given the following code snippet (taken from this repo):

    # Imports the excerpt relies on (tf.keras paths)
    import numpy as np
    from tensorflow.keras.datasets import mnist
    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    class GAN():
        def __init__(self):
            ...
            # Build and compile the discriminator
            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',
                                       optimizer=optimizer,
                                       metrics=['accuracy'])

            # Build the generator
            self.generator = self.build_generator()

            # The generator takes noise as input and generates imgs
            z = Input(shape=(self.latent_dim,))
            img = self.generator(z)

            # For the combined model we will only train the generator
            self.discriminator.trainable = False

            # The discriminator takes generated images as input and determines validity
            validity = self.discriminator(img)

            # The combined model (stacked generator and discriminator)
            # Trains the generator to fool the discriminator
            self.combined = Model(z, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        def build_generator(self):
            ...
            return Model(noise, img)

        def build_discriminator(self):
            ...
            return Model(img, validity)

        def train(self, epochs, batch_size=128, sample_interval=50):
            # Load the dataset
            (X_train, _), (_, _) = mnist.load_data()

            # Adversarial ground truths
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))

            for epoch in range(epochs):
                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)

During the definition of the model self.combined, the discriminator's weights are frozen with self.discriminator.trainable = False, but they are never turned back on.

Still, during the training loop, the discriminator's weights change with the lines:

    # Train the discriminator
    d_loss_real = self.discriminator.train_on_batch(imgs, valid)
    d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

and stay constant during:

    # Train the generator (to have the discriminator label samples as valid)
    g_loss = self.combined.train_on_batch(noise, valid)

which I would not expect.

Of course, this is the correct (iterative) way to train a GAN, but I don't understand why we don't have to set self.discriminator.trainable = True before we can train the discriminator.

It would be nice if someone could explain this; I guess it is a crucial point to understand.

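1 answer:

Answer 0 (score: 1)

When you have questions about code from a GitHub repository, it is often a good idea to check the issues (both open and closed). This issue explains why the flag is set to False. It says: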

  

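Since self.discriminator.trainable = False is set after the discriminator is compiled, it does not affect the training of the discriminator. However, since it is set before the combined model is compiled, the discriminator layers are frozen when the combined model is trained.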
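The ordering argument above can be illustrated with a toy model in plain Python. This is only a sketch under the assumption stated in the quoted issue, namely that compile() records which layers are trainable at that moment. ToyLayer and ToyModel are hypothetical names invented for this illustration; they do not reproduce the Keras API or its internals.

```python
class ToyLayer:
    """Hypothetical layer: one scalar weight and a trainable flag."""

    def __init__(self, name):
        self.name = name
        self.weight = 0.0
        self.trainable = True


class ToyModel:
    """Hypothetical model that snapshots trainable flags in compile()."""

    def __init__(self, layers):
        self.layers = layers
        self._compiled_trainable = None  # filled in by compile()

    @property
    def trainable(self):
        return all(layer.trainable for layer in self.layers)

    @trainable.setter
    def trainable(self, value):
        # Setting the flag on the model propagates to its layers.
        for layer in self.layers:
            layer.trainable = value

    def compile(self):
        # Assumption: the trainable state is captured here, once.
        self._compiled_trainable = [layer.trainable for layer in self.layers]

    def train_on_batch(self, step=1.0):
        if self._compiled_trainable is None:
            raise RuntimeError("compile() must be called before training")
        # Only layers that were trainable at compile time are updated.
        for layer, was_trainable in zip(self.layers, self._compiled_trainable):
            if was_trainable:
                layer.weight += step


d_layer = ToyLayer("discriminator")
g_layer = ToyLayer("generator")

discriminator = ToyModel([d_layer])
discriminator.compile()            # snapshot: discriminator layer trainable

discriminator.trainable = False    # flipped AFTER the discriminator compile

combined = ToyModel([g_layer, d_layer])
combined.compile()                 # snapshot: generator trainable, discriminator frozen

discriminator.train_on_batch()     # still updates the discriminator layer
d_after_discriminator_step = d_layer.weight

combined.train_on_batch()          # updates only the generator layer
d_after_combined_step = d_layer.weight
```

In this toy, the discriminator layer changes when discriminator.train_on_batch() runs and stays fixed when combined.train_on_batch() runs, even though the layer object is shared and its flag is never set back to True. Whether a particular Keras or TensorFlow release captures the flag in exactly this way is worth checking against its own documentation.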

It also discusses freezing Keras layers.