训练GAN时批次中的所有图像均相同

时间:2019-06-10 09:14:56

标签: python tensorflow generative-adversarial-network

def train(self,gray_scale_image_dataset,color_image_dataset,test_image):
    SEED= 50
    random.seed(SEED)
    generator = self.generator_model()
    discriminator = self.discriminator_model()

    gray_scale_images = gray_scale_image_dataset
    colored_images = color_image_dataset

    gen_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
    dis_optimizer = tf.train.AdamOptimizer(self.learning_rate,beta1=0.5)
    for eachEpoch in range(self.epochs):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            for i in range(20):
                random.shuffle(gray_scale_image_dataset)
                random.shuffle(color_image_dataset)
            gray_scale_dataset_image = gray_scale_images[:self.batch_size]
            print(gray_scale_dataset_image.shape)
            color_dataset_image_batch = colored_images[:self.batch_size]
            #lets see which images are being trained

            self.draw_train_images(color_dataset_image_batch) 
            generated_image = generator(gray_scale_dataset_image)
            real_output = discriminator(color_dataset_image_batch)

            fake_output = discriminator(generated_image)
            print("What  Discriminator Thought about real_output = {} and fake output = {}".format(real_output[0],fake_output[0]))
            gen_loss = self.generator_loss(fake_output,generated_image,color_dataset_image_batch)
            dis_loss = self.discriminator_loss(fake_output,real_output)
            print("GEN LOSS {} and DISC = {}".format(gen_loss[0],dis_loss[0]))

        gen_gradients = gen_tape.gradient(gen_loss,generator.trainable_variables)
        disc_gradients = disc_tape.gradient(dis_loss,discriminator.trainable_variables)
        gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
        dis_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

        print ("EPOCHS COMPLETED = {} ".format(eachEpoch))
        self.draw_images(generator,test_image)

这是我的训练功能,实际上是训练gan网络的。生成器和鉴别器是两个不同的网络,是在本文之后编写的,因此这没有问题。我的问题是正在传递的图像。我已经检查了图像,gray_scale_image_dataset,color_image_dataset和它们都工作良好。但是,当我转到self.draw_train_images函数并尝试在matplotlib中绘制它们时,网格中仅显示了第一张图像。整个网格充满了第一个图像,只有一个图像用于训练数据,所以我遇到了很多错误。关于这个问题有帮助吗?我在哪里弄乱东西?

0 个答案:

没有答案