我使用版本为2.2.4的keras在CelebA数据集上训练了DCGAN。我对keras功能API实现感到困惑,例如关于DCGAN体系结构中BN层的 predict 。
在训练过程中,我需要使用 generator.predict(noise)来获取一些用于鉴别器训练的假图像。在训练过程中,我需要抽取样本以显示网络是否使用进行了改进> generator.predict(noise)。我知道,在进行鉴别器训练期间,由生成器生成的假图像应以训练模式通过生成器中的BN层,以及何时从发生器生成样本的方式应该在测试模式中完成。由于在两种情况下我都使用 predict 方法,我想知道如何确保培训和测试模式正确使用该方法。任何人都可以告诉我是否完成了培训,当我使用预测方法时BN层处于哪种模式。谢谢。
代码在一次迭代中是这样的:
noise = np.random.normal(0, 1, size=(batch_size, noise_dims))
fake_batch = generator.predict(noise) # are the BN layers in training mode?
real_batch = data_loader.load_data(batch_size)
# train discriminator
discriminator.train_on_batch(real_batch, np.ones(batch_size)*0.9)
discriminator.train_on_batch(fake_batch, np.zeros(batch_size))
# train generator
discriminator.trainable = False
noise = np.random.normal(0, 1, size=(batch_size, noise_dims))
combine_model.train_on_batch(noise, np.ones(batch_size))
# combine_model is built by Model(noise, discriminator(generator(noise)))
# I also wonder whether this generator and discriminator call are in training mode
if iter % sample_interval == 0:
noise = np.random.normal(0, 1, size=(sample_size, noise_dims))
generator.predict(noise)
# are the BN layers run in test mode?
想知道我对DCGAN中BN层的理解是否正确以及keras中的预测方法如何处理BN或辍学中的训练或测试模式。最后,如果可能的话,如何手动控制使用keras(set_learning_phase?)的DCGAN培训。