Keras实施改进的生成性对抗网络

时间:2018-05-03 13:24:58

标签: python keras generative-adversarial-network

我正在使用a keras implemented code来开发Wasserstein Generative Adversarial Gans的改进版本。我正在使用提供的Mnist数据库测试代码。在代码中,当读取Mnist时,还会加载图像的标签。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

鉴别器的损失函数:

 discriminator_loss.append(discriminator_model.train_on_batch([image_batch, noise],[positive_y, negative_y, dummy_y])) 

对于发电机:

 generator_loss.append(generator_model.train_on_batch(np.random.rand(BATCH_SIZE, 100), positive_y))

在GAN过程中,这些标签被忽略了。如何使用该信息来生成特定标签的图像?

编辑:我注意到我需要更改的行数在203到217之间

discriminator.trainable = False
generator_input = Input(shape=(100,))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(inputs=[generator_input], outputs= 
[discriminator_layers_for_generator])
//We use the Adam paramaters from Gulrajani et al.
generator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9), 
    loss=wasserstein_loss)
...

real_samples = Input(shape=X_train.shape[1:])
generator_input_for_discriminator = Input(shape=(100,))
generated_samples_for_discriminator = 
generator(generator_input_for_discriminator)
discriminator_output_from_generator = 
       discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)

我想我还需要修改模型构建器make_generator和make_discriminator。

0 个答案:

没有答案