我一直在关注GAN的图像着色。我曾尝试自己实施研究论文,但遇到了问题。在特定点之后,生成器和鉴别器的损失不会收敛。因此,我认为它们的实现存在问题。这是代码
def generator_loss(self,fake_output_discri,generated_image_from_generator, actual_image,regularizer_lambda=0.01):
mse = tf.reduce_mean(regularizer_lambda*tf.keras.losses.mean_absolute_error(generated_image_from_generator, actual_image))
return tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(fake_output_discri),logits = fake_output_discri) + mse
def discriminator_loss(self,generated_image_from_generator,actual_image):
actual_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(actual_image), logits = actual_image )
fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(generated_image_from_generator), logits = generated_image_from_generator)
return actual_loss + fake_loss
假设传递给他们的参数正确,我做错了什么吗?我尝试打印生成器和鉴别器损耗,并且在某些时期后保持不变!这是完成大部分工作的火车功能!
def train(self,gray_scale_image_dataset,color_image_dataset,test_image):
generator = self.generator_model()
discriminator = self.discriminator_model()
gen_optimizer = tf.train.AdamOptimizer(self.learning_rate)
dis_optimizer = tf.train.AdamOptimizer(self.learning_rate)
for eachEpoch in range(self.epochs):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
for i in range(20):
random.shuffle(gray_scale_image_dataset)
random.shuffle(color_image_dataset)
gray_scale_dataset_image = gray_scale_image_dataset[:self.batch_size]
color_dataset_image_batch = color_image_dataset[:self.batch_size]
generated_image = generator(gray_scale_dataset_image)
real_output = discriminator(color_dataset_image_batch)
fake_output = discriminator(generated_image)
gen_loss = self.generator_loss(fake_output,generated_image,color_dataset_image_batch)
dis_loss = self.discriminator_loss(fake_output,real_output)
print("generator = {} discriminator = {}".format(gen_loss,dis_loss))
gen_gradients = gen_tape.gradient(gen_loss,generator.trainable_variables)
disc_gradients = disc_tape.gradient(dis_loss,discriminator.trainable_variables)
print("APPLYING GRADENTS")
gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
dis_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
print ("EPOCHS COMPLETED = {} ".format(eachEpoch))
#for drawing test_image
self.draw_images(generator,test_image)