我正在使用a keras implemented code来开发Wasserstein Generative Adversarial Gans的改进版本。我正在使用提供的Mnist数据库测试代码。在代码中,当读取Mnist时,还会加载图像的标签。
(X_train, y_train), (X_test, y_test) = mnist.load_data()
鉴别器的损失函数:
discriminator_loss.append(discriminator_model.train_on_batch([image_batch, noise],[positive_y, negative_y, dummy_y]))
对于发电机:
generator_loss.append(generator_model.train_on_batch(np.random.rand(BATCH_SIZE, 100), positive_y))
在GAN过程中,这些标签被忽略了。如何使用该信息来生成特定标签的图像?
编辑:我注意到我需要更改的行数在203到217之间
discriminator.trainable = False
generator_input = Input(shape=(100,))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(inputs=[generator_input], outputs=
[discriminator_layers_for_generator])
//We use the Adam paramaters from Gulrajani et al.
generator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
loss=wasserstein_loss)
...
real_samples = Input(shape=X_train.shape[1:])
generator_input_for_discriminator = Input(shape=(100,))
generated_samples_for_discriminator =
generator(generator_input_for_discriminator)
discriminator_output_from_generator =
discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)
我想我还需要修改模型构建器make_generator和make_discriminator。