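I want to use TPUGANEstimator to train, on a TPU, a modified version of the pix2pix GAN that takes an additional condition image (you can find the original, non-TPU publication here: https://phillipi.github.io/pix2pix/). To train the generator I need to feed it input_image and condition_image; the discriminator gets target_image and condition_image. My question is how the Estimator maps the dictionary created in input_fn to the correct inputs of each model. Here is pseudocode for my input function: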
```python
import tensorflow as tf

def input_fn(mode, params):
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    # Yields one dict per example holding input_image, condition_image and target_image
    data_gen = training_generator if is_train else test_generator
    dataset = tf.data.Dataset.from_generator(
        data_gen,
        output_types={'input': tf.float32, 'condition': tf.float32, 'target': tf.float32},
        output_shapes={'input': tf.TensorShape(shape),
                       'condition': tf.TensorShape(shape),
                       'target': tf.TensorShape(shape)})
    if is_train:
        dataset = dataset.shuffle(
            buffer_size=int(config["shuffle_buffer_ratio"] * config["batch_size"]))
    dataset = dataset.batch(config["batch_size"])
    # Prefetch whole batches so input preparation overlaps with training
    dataset = dataset.prefetch(config["max_buffer"])
    # return dataset
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_item = iterator.get_next()
    return next_item
```
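The core of the question is which part of each dataset element reaches which function. Assuming TF-GAN's estimators follow the convention that `input_fn` returns a `(features, labels)` pair, where `features` are passed to `generator_fn` as the generator inputs and `labels` are used as the real data for the discriminator (an assumption to verify against the TF-GAN docs), each element can be split as in the sketch below. It is plain Python so it runs without TensorFlow; `toy_generator` and `split_features_labels` are hypothetical names, and integers stand in for images.

```python
from typing import Any, Dict, Iterator, Tuple


def toy_generator(num_examples: int) -> Iterator[Dict[str, Any]]:
    """Hypothetical stand-in for training_generator: yields one dict per example."""
    for i in range(num_examples):
        # Real code would yield image arrays; integers keep the sketch dependency-free.
        yield {'input': i, 'condition': 10 + i, 'target': 100 + i}


def split_features_labels(example: Dict[str, Any]) -> Tuple[Dict[str, Any], Any]:
    """Separates the generator inputs (features) from the real image (labels)."""
    features = {'input': example['input'], 'condition': example['condition']}
    labels = example['target']
    return features, labels


pairs = [split_features_labels(ex) for ex in toy_generator(2)]
print(pairs)  # [({'input': 0, 'condition': 10}, 100), ({'input': 1, 'condition': 11}, 101)]
```

In the TensorFlow pipeline the same function could be applied with `dataset = dataset.map(split_features_labels)` before batching, since `Dataset.map` passes a dict element to the mapped function as a single argument.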
The first problem is defining the generator so that it uses both inputs from the dictionary:
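```python
def generator_fn(input_dict, mode='TRAIN', scope='Generator'):
    input_image = input_dict['input']
    condition = input_dict['condition']
    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        # Assumption: combine both images along the channel axis
        x = tf.concat([input_image, condition], axis=-1)
        x = generator_network(x)
    return x
```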
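Both networks need to combine two images into one input. A common choice, and the one pix2pix's discriminator uses, is concatenation along the channel axis, which is what `tf.concat([a, b], axis=-1)` does for channels-last image tensors; using it for the generator's condition image as well is an assumption. The plain-Python sketch below mimics that operation on nested lists so the resulting shape is easy to check; `concat_channels` is a hypothetical helper, not a TensorFlow function.

```python
from typing import List

Image = List[List[List[float]]]  # height x width x channels, as nested lists


def concat_channels(a: Image, b: Image) -> Image:
    """Plain-Python analogue of tf.concat([a, b], axis=-1) for H x W x C images."""
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        raise ValueError("images must share height and width")
    # For every pixel, append b's channels after a's channels.
    return [[pa + pb for pa, pb in zip(ra, rb)] for ra, rb in zip(a, b)]


input_image = [[[0.1, 0.2, 0.3]] * 2] * 2  # 2 x 2 x 3
condition = [[[0.9, 0.8, 0.7]] * 2] * 2    # 2 x 2 x 3
stacked = concat_channels(input_image, condition)
print(len(stacked), len(stacked[0]), len(stacked[0][0]))  # 2 2 6
```

Two 3-channel images thus become one 6-channel input, so the first layer of `generator_network` (and of `discriminator_network`, if it is conditioned the same way) must accept six channels.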
The second problem is how to switch the discriminator between fake and real images (possibly using the joint_train parameter):
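```python
def discriminator_fn(image, input_dict, scope='Discriminator'):
    condition = input_dict['condition']
    # Assumption: condition the discriminator by channel concatenation
    x = tf.concat([image, condition], axis=-1)
    x = discriminator_network(x)
    return x
```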
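On switching between fake and real images: TF-GAN's `gan_model`, which its estimators build on, calls `discriminator_fn(data, generator_inputs)` twice, once with the generator's output and once with the real data, passing the same generator inputs both times. If that holds for `TPUGANEstimator` as well, `discriminator_fn` does not have to switch anything itself and can read the condition image from its second argument. The sketch below is a simplified plain-Python model of that call pattern, not TF-GAN's actual code; every function name in it is hypothetical.

```python
from typing import Any, Callable, Dict, Tuple


def simulate_gan_forward(generator_fn: Callable, discriminator_fn: Callable,
                         features: Dict[str, Any], labels: Any) -> Tuple[Any, Any]:
    """Simplified model of the call pattern: one generator call, two discriminator calls."""
    generated = generator_fn(features)
    fake_logits = discriminator_fn(generated, features)  # generated image + same inputs
    real_logits = discriminator_fn(labels, features)     # real target + same inputs
    return real_logits, fake_logits


calls = []


def toy_generator_fn(inputs):
    return ('fake', inputs['input'], inputs['condition'])


def toy_discriminator_fn(image, inputs):
    # Records what the discriminator saw; the condition comes from the shared inputs.
    calls.append((image, inputs['condition']))
    return 1.0 if image == 'real_target' else 0.0


real, fake = simulate_gan_forward(toy_generator_fn, toy_discriminator_fn,
                                  {'input': 'in', 'condition': 'cond'}, 'real_target')
print(real, fake)  # 1.0 0.0
```

`joint_train` is documented as choosing whether the generator and discriminator are trained jointly or sequentially; it does not select which images the discriminator sees.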
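I defined the TPUGANEstimator according to the documentation, but I could not find a good tutorial or example showing how to feed the models the correct inputs.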
```python
import tensorflow_gan as tfgan

gan_estimator = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5),
    joint_train=False,
    gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
    model_dir=config['model_dir'],
    params=params,
    use_tpu=False,
    train_batch_size=2,
    eval_batch_size=2,
    config=t_config)

cur_step = 0
steps_per_iteration = 1000  # hypothetical chunk size
while cur_step < number_of_steps:
    print("Running gan estimator: {}".format(cur_step))
    # `steps` counts *additional* steps, so advance cur_step by the same amount
    gan_estimator.train(train_input_fn, steps=steps_per_iteration)
    cur_step += steps_per_iteration
```
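`Estimator.train(input_fn, steps=N)` runs N additional steps and rejects `steps <= 0`, while `max_steps` is an absolute limit on the global step. A loop around `train()` therefore has to advance its own counter, or it never terminates. The plain-Python sketch below shows that bookkeeping; `plan_training_chunks` and `steps_per_iteration` are hypothetical names.

```python
from typing import List


def plan_training_chunks(number_of_steps: int, steps_per_iteration: int) -> List[int]:
    """Returns the `steps` value for each train() call so the calls add up to number_of_steps."""
    if steps_per_iteration <= 0:
        # Mirrors the Estimator's own check: steps must be positive.
        raise ValueError("steps_per_iteration must be > 0")
    chunks = []
    cur_step = 0
    while cur_step < number_of_steps:
        steps = min(steps_per_iteration, number_of_steps - cur_step)
        chunks.append(steps)
        cur_step += steps  # without this update the loop would never end
    return chunks


print(plan_training_chunks(2500, 1000))  # [1000, 1000, 500]
```

Alternatively, a single call `gan_estimator.train(train_input_fn, max_steps=number_of_steps)` trains until the global step reaches `number_of_steps`, which also behaves correctly when resuming from a checkpoint in `model_dir`.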
Thanks a lot in advance for your help.