在CIFAR10数据集上训练深度卷积GAN时,损失没有变化

时间:2019-09-17 18:25:50

标签: python tensorflow keras

该模型根本没有训练。我认为梯度无法正常流动,或者可能存在一些逻辑错误。

即使经过几个时期,损失也没有变化。 我尝试了许多代码变体,但无法调试问题。

鉴别器和发电机损耗没有变化。

下面是我的代码:

from keras.layers import Conv2D, Dense, LeakyReLU, BatchNormalization, Reshape 
from keras.layers import Flatten, Input, MaxPooling2D, Dropout, AlphaDropout, Conv2DTranspose
from keras.optimizers import Adam
from keras.models import Model, Sequential
from keras.datasets import cifar10
import keras.backend as K
import cv2
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt

class CifarGAN:
    def __init__(self, img_rows, img_cols, channels, epochs, batch_size, latent_dim):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.epochs = epochs
        self.batch_size = batch_size
        self.latent_dim = latent_dim

        self.optimizer = Adam(0.0002, beta_1=0.5)

        # build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(optimizer=self.optimizer, loss='binary_crossentropy', metrics=['accuracy'])

        # build and compile generator
        self.generator = self.build_generator()
        self.generator.compile(optimizer=self.optimizer, loss='binary_crossentropy')        

        # input noise
        noise = Input((self.latent_dim, ))
        img = self.generator(noise)

        # for the combined model we will only train the generator
        self.discriminator.trainable = False

        valid = self.discriminator(img)

        self.combined = Model(inputs=noise, outputs=valid, name='combined')
        self.combined.compile(optimizer=self.optimizer, loss='binary_crossentropy')
        # self.combined.summary()

    def build_generator(self):

        model = Sequential(name='generator')

        model.add(Dense(256*4*4, name='gen_dense_2'))
        model.add(LeakyReLU(alpha=0.2, name='gen_lr_2'))

        model.add(Reshape((4, 4, 256), name='gen_reshape_1'))

        model.add(Conv2DTranspose(128, kernel_size=(4, 4), strides=(2, 2), padding='same', kernel_initializer='he_normal'))
        # model.add(BatchNormalization(name='gen_bc_1'))
        model.add(LeakyReLU(alpha=0.2, name='gen_lr_3'))
        # model.add(AlphaDropout(0.1))

        model.add(Conv2DTranspose(128, kernel_size=(4, 4), strides=(2, 2), padding='same', kernel_initializer='he_normal'))
        # model.add(BatchNormalization(name='gen_bc_2'))
        model.add(LeakyReLU(alpha=0.2, name='gen_lr_4'))
        # model.add(AlphaDropout(0.1))

        model.add(Conv2DTranspose(256, kernel_size=(4, 4), strides=(2, 2), padding='same', kernel_initializer='he_normal'))
        # model.add(BatchNormalization(name='gen_bc_3'))
        model.add(LeakyReLU(alpha=0.2, name='gen_lr_5'))
        # model.add(AlphaDropout(0.1))

        model.add(Conv2D(3, kernel_size=(3, 3), padding='same', activation='tanh'))

        noise = Input((self.latent_dim, ), name='gen_input')
        img = model(noise)
        # model.summary()
        return Model(noise, img)


    def build_discriminator(self):
        model = Sequential(name='discriminator')

        model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal', name='dis_conv_1'))
        model.add(BatchNormalization(name='dis_bc_1'))
        model.add(LeakyReLU(alpha=0.2, name='dis_lr_1'))

        model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal', name='dis_conv_2'))
        model.add(BatchNormalization(name='dis_bc_2'))
        model.add(LeakyReLU(alpha=0.2, name='dis_lr_2'))

        model.add(Flatten(name='dis_flatten'))

        model.add(Dense(256, name='dis_dense_1'))
        model.add(LeakyReLU(alpha=0.2, name='dis_lr_4'))

        model.add(Dense(1, activation='sigmoid'))

        image = Input(self.img_shape, name='dis_input')
        valid = model(image)
        # model.summary()

        return Model(image, valid)

    def generate_real_samples(self):
        gen_imgs = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        labels = np.ones((self.batch_size, 1))
        return gen_imgs, labels

    def generate_fake_samples(self):
        noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        imgs = self.generator.predict(noise)
        labels = np.zeros((self.batch_size, 1))
        return imgs, labels

    def sample_original_image(self):
        idx = np.random.randint(0, self.samples, self.batch_size)
        imgs = self.X_train[idx]
        labels = np.ones((self.batch_size, 1))
        return imgs, labels

    def get_data(self):
        (X_train, _), (_, _) = cifar10.load_data()
        self.samples = X_train.shape[0]
        # Rescale the image to -1 to 1
        X_train = (X_train.astype('float32') - 127.5) / 127.5
        return X_train

    def train(self):
        self.X_train = self.get_data()

        mini_batch = int(self.batch_size / 2)
        for epoch in range(1, self.epochs+1):
            for batch in range(mini_batch):

                imgs, labels = self.sample_original_image()
                gen_imgs, gen_labels = self.generate_fake_samples()

                # Train the discriminator
                d_loss_real, _ = self.discriminator.train_on_batch(imgs, labels)
                d_loss_fake, _ = self.discriminator.train_on_batch(gen_imgs, gen_labels)

                # The generator wants the discriminator to label the generated samples
                # as valid (ones)
                noise, valid_y = self.generate_real_samples()

                # Train the generator
                g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress every 100 iterations:
            if epoch % 10 == 0:
                print ("epoch: %d / %d [D loss real: %f, D loss fake: %f] [G loss: %f]" % (epoch, self.epochs, d_loss_real, d_loss_fake, g_loss))


img_rows, img_cols, channels, epochs, batch_size, latent_dim = 32, 32, 3, 100, 64, 100
cifar_gan = CifarGAN(img_rows, img_cols, channels, epochs, batch_size, latent_dim)
cifar_gan.train()

时代

epoch: 0 / 100 [D loss real: 0.000000, D loss fake: 15.942385] [G loss: 0.000000]
epoch: 10 / 100 [D loss real: 0.000000, D loss fake: 15.942385] [G loss: 0.000000]
epoch: 20 / 100 [D loss real: 0.000000, D loss fake: 15.942385] [G loss: 0.000000]

谢谢。

0 个答案:

没有答案