为TPU上的GAN训练准备数据

时间:2019-09-24 15:54:04

标签: tensorflow generative-adversarial-network tpu gan

我想使用TPUGANEstimator在TPU上训练带有附加条件图像的pix2pix GAN的修改版(您可以在此处找到原始出版物(非TPU):https://phillipi.github.io/pix2pix/)。为了训练生成器,我需要将其输入input_image和condition_image,对于鉴别器,我将输入target_image和condition_image。我的问题是Estimator如何为在input_fn中创建的字典找到模型的正确输入。这是我的输入功能代码的伪代码:

def input_fn(mode, params):
    is_train = mode == tf.estimator.ModeKeys.TRAIN

    # Yields a triplet of input_image, condition_image, target_image
    data_gen = training_generator if is_train else test_generator 

    dataset = tf.data.Dataset.from_generator(data_gen,
        ({'input': tf.float32, 'condition': tf.float32, 'target': tf.float32}),
        output_shapes=({'input': tf.TensorShape(shape), 'condition': tf.TensorShape(shape),
                        'target': tf.TensorShape(shape)}))
    if is_train:
        dataset = dataset.shuffle(buffer_size=int(config["shuffle_buffer_ratio"] * config["batch_size"]))

    dataset = dataset.prefetch(config["max_buffer"])
    dataset = dataset.batch(config["batch_size"])
    # return dataset

    iterator = dataset.make_one_shot_iterator()
    next_item = iterator.get_next()
    return next_item

第一个问题是定义生成器以使用字典中的两个输入:

def generator_fn(input_dict, mode='TRAIN', scope='Generator'):
    input = input_dict[‘input’]
    condition = input_dict[‘condition’]
    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        x = generator_network(x)
    return x

第二个问题是如何在鉴别器的伪图像和有效图像之间切换序列(可能使用joint_train参数):

def discriminator_fn(image, input_dict, scope='Discriminator'):
    x = discriminator_network(x)
return x

我根据文档定义了TPUGANEstimator,但找不到如何生成正确序列的良好教程/示例。

gan_estimator = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    joint_train=False,
    gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
    model_dir=config['model_dir'],
    params=params,
    use_tpu=False,
    train_batch_size=2,
    eval_batch_size=2,
    config=t_config)

while cur_step < number_of_steps:
    print("Running gan estimator: {}".format(cur_step))
    gan_estimator.train(train_input_fn, steps=cur_step)

非常感谢您提前提供帮助。

0 个答案:

没有答案