I'm trying to get a basic understanding of WGANs, but I have a few questions. I followed this article and a WGAN implementation that comes with a git repo.
Until now, whenever I trained a GAN model, I had an epoch loop and, inside it, a batch loop, like this:
for epoch in range(epochs):
    for batch in dataset:
        noise = tf.random.normal([BATCH_SIZE, noise_dim])
        generated_images = generator(noise, training=True)
        real_imgs_predictions = discriminator(batch, training=True)
        fake_imgs_predictions = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_imgs_predictions)  # my own loss function
        disc_loss = discriminator_loss(real_imgs_predictions, fake_imgs_predictions)  # my own loss function
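For context, inside that batch loop the losses feed into gradient updates roughly like this (a minimal sketch of my setup; the tape and optimizer names gen_tape, disc_tape, gen_optimizer, disc_optimizer are my placeholders, not from the article, and the forward pass has to be repeated inside the tapes):

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    # Forward pass and losses are recorded on the tapes
    generated_images = generator(noise, training=True)
    real_imgs_predictions = discriminator(batch, training=True)
    fake_imgs_predictions = discriminator(generated_images, training=True)
    gen_loss = generator_loss(fake_imgs_predictions)
    disc_loss = discriminator_loss(real_imgs_predictions, fake_imgs_predictions)
# Each network is updated with its own optimizer (placeholder names)
gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
gen_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
disc_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))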
Now, in this git repo (and in the article), I see an epoch loop, but instead of running over batches it uses the train_on_batch function, and as far as I can tell (maybe I'm wrong), only one batch is used per epoch:
for epoch in range(epochs):
    for _ in range(self.n_critic):
        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        # Sample noise as generator input
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)
        # Train the critic
        d_loss_real = self.critic.train_on_batch(imgs, valid)
        d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
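By analogy with my usual loop, what I would have expected is something like the following, where each epoch runs over enough random batches to cover the dataset once (this is my own sketch, not from the repo; batches_per_epoch is my name):

batches_per_epoch = X_train.shape[0] // batch_size
for epoch in range(epochs):
    for _ in range(batches_per_epoch):
        for _ in range(self.n_critic):
            # Same random-batch sampling and critic update as above
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
            d_loss_real = self.critic.train_on_batch(imgs, valid)
            d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)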
1. Is this one of the differences between a GAN and a WGAN?
2. What is the advantage of using only one batch per epoch?
3. I know that in a WGAN we train the critic/discriminator more than the generator, but why don't we use many batches per epoch?