Using a combined adversarial loss with GANEstimator

Asked: 2018-04-17 11:08:51

Tags: python tensorflow computer-vision deep-learning keras

I am trying to combine an L1 pixel loss with an adversarial loss to learn autoencoded images. The code is below.

gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,
    discriminator_fn=nets.discriminator,
    real_data=images,
    generator_inputs=images)

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss,
    weight_factor=FLAGS.weight_factor)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
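One detail worth noting in the snippet above: `tf.norm(x, ord=1)` with the default `axis=None` flattens the tensor and returns the *sum* of absolute values, whereas `tf.losses.absolute_difference` (used in the answer below) returns the *mean*, so the two L1 losses differ by a factor of the element count and `weight_factor` would need rescaling to match. A small NumPy sketch of that difference (the arrays are made-up examples):

```python
import numpy as np

real = np.array([[1.0, 2.0], [3.0, 4.0]])
generated = np.array([[1.5, 1.0], [2.0, 4.0]])

diff = np.abs(real - generated)  # element-wise |real - generated|
l1_sum = diff.sum()              # what tf.norm(..., ord=1) computes on the flattened tensor
l1_mean = diff.mean()            # what tf.losses.absolute_difference computes by default

print(l1_sum)   # → 2.5
print(l1_mean)  # → 0.625
# The two differ exactly by the number of elements:
print(l1_sum == l1_mean * diff.size)  # → True
```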

However, I would like to use GANEstimator to simplify the code. A typical GANEstimator example looks like this:

gan_estimator = tfgan.estimator.GANEstimator(
    model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))

# Train estimator.
gan_estimator.train(train_input_fn, steps)
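For completeness: `train_input_fn` here follows the standard Estimator contract of returning a `(features, labels)` pair, which GANEstimator treats as `(generator_inputs, real_data)`; in the autoencoding setup from the question both elements are the same image batch. A trivial NumPy sketch of that pairing (`make_autoencoder_pair` is a hypothetical helper, not part of the TF-GAN API):

```python
import numpy as np

def make_autoencoder_pair(images):
    """Return (generator_inputs, real_data) for an autoencoding GAN.

    In the question's setup the generator reconstructs its own input,
    so both elements of the pair are the same batch of images.
    Hypothetical helper, not part of TF-GAN.
    """
    return images, images

batch = np.zeros((8, 32, 32, 3), dtype=np.float32)  # dummy image batch
generator_inputs, real_data = make_autoencoder_pair(batch)
print(generator_inputs.shape == real_data.shape)  # → True
```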

Does anyone know how to use combine_adversarial_loss with GANEstimator?

Thanks.

2 answers:

Answer 0 (score: 1)

I just ran into the same problem (this solution works on TensorFlow r1.12).

Reading through the code, tfgan.losses.combine_adversarial_loss takes the gan_loss tuple and replaces its generator loss with the combined adversarial loss. This means we need to replace generator_loss_fn in the estimator. All of the estimator's loss functions take the arguments gan_model, **kwargs, so we define our own function with that signature and use it as the generator loss function:

def combined_loss(gan_model, **kwargs):
    # Define non-adversarial loss - for example L1
    non_adversarial_loss = tf.losses.absolute_difference(
        gan_model.real_data, gan_model.generated_data)
    # Define generator loss
    generator_loss = tf.contrib.gan.losses.wasserstein_generator_loss(
        gan_model,
        **kwargs)
    # Combine these losses - you can specify more parameters
    # Exactly one of weight_factor and gradient_ratio must be non-None
    combined_loss = tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=FLAGS.weight_factor,
        gradient_ratio=None,
        variables=gan_model.generator_variables,
        scalar_summaries=kwargs['add_summaries'],
        gradient_summaries=kwargs['add_summaries'])
    return combined_loss


gan_estimator = tf.contrib.gan.estimator.GANEstimator(
    model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=combined_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

For more information on the parameters, see the docs for tfgan.losses.wargs.combine_adversarial_loss. Also, **kwargs is not directly compatible with the combined adversarial loss function, so I used a small trick here (pulling add_summaries out of kwargs and passing it explicitly).
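For intuition about what the combination does: with `weight_factor` set and `gradient_ratio=None`, `combine_adversarial_loss` essentially returns a weighted sum of the non-adversarial and adversarial losses (the `gradient_ratio` mode instead rescales the adversarial term based on the relative gradient norms). A rough sketch of the `weight_factor` path only, ignoring summaries and the gradient machinery:

```python
def combine_adversarial_loss_sketch(main_loss, adversarial_loss, weight_factor):
    """Simplified view: combined = main + weight_factor * adversarial.

    The real tfgan.losses.wargs.combine_adversarial_loss also supports a
    gradient_ratio mode and emits summaries; this sketch models only the
    weight_factor path.
    """
    return main_loss + weight_factor * adversarial_loss

l1 = 0.8    # non-adversarial (pixel) loss
adv = 0.25  # adversarial (Wasserstein generator) loss
combined = combine_adversarial_loss_sketch(l1, adv, weight_factor=2.0)
print(combined)  # → 1.3
```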

Answer 1 (score: 0)

Following the link, GANEstimator has the following parameters:

 generator_loss_fn=None,
 discriminator_loss_fn=None,

generator_loss_fn should be your L1 pixel loss.

discriminator_loss_fn should be your combine_adversarial_loss.