GAN没有很好地进行训练

时间:2018-06-09 10:19:12

标签: python keras generative-adversarial-network

我使用keras编写了一个GAN模型但是训练并不顺利。生成器模型总是返回一个裸噪声图像(28x28大小),而不是类似于mnist数据集的东西。但是,这并没有给我任何错误,当涉及到训练鉴别器模型将成为trainable=False,这不是我想要做的。

如果此实施不好,请告诉我。有人可以帮忙吗?

import os
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, BatchNormalization
from keras.optimizers import SGD, Adam, RMSprop
from keras.datasets import mnist
from keras.regularizers import l1_l2

def plot_generated(noise, Generator):
    image_fake = Generator.predict(noise)
    plt.figure(figsize=(10,8))
    plt.show()
    plt.close()

def plot_metircs(metrics, epoch=None):
    plt.figure(figsize=(10,8))
    plt.plot(metrics['d'], label='discriminative loss', color='b')
    plt.legend()
    plt.show()
    plt.close()

    plt.figure(figsize=(10,8))
    plt.plot(metrics['g'], label='generative loss', color='r')
    plt.legend()
    plt.show()
    plt.close()

def Generator():
    model = Sequential()
    LeakyReLU = keras.layers.advanced_activations.LeakyReLU(alpha=0.2)
    model.add(Dense(input_dim=100, units=128, activation=LeakyReLU, name='g_input'))
    model.add(Dense(input_dim=128, units=784, activation='tanh', name='g_output'))
    return model

def Discriminator():
    model = Sequential()
    LeakyReLU = keras.layers.advanced_activations.LeakyReLU(alpha=0.2)
    model.add(Dense(input_dim=784, units=128, activation=LeakyReLU, name='d_input'))
    model.add(Dense(input_dim=128, units=1, activation='sigmoid', name='d_output'))
    model.compile(loss='binary_crossentropy', optimizer='Adam')
    return model

def Generative_Adversarial_Network(Generator, Discriminator):
    model = Sequential()
    model.add(Generator)
    model.add(Discriminator)
    # train only generator in the entire GAN architecture
    Discriminator.trainable = False
    model.compile(loss='binary_crossentropy', optimizer='Adam')
    return model

def Training(z_input_size, Generator, Discriminator, GAN, loss_dict, X_train, epoch, batch, smooth):
    for e in range(epoch):
        # z: noise, used for input of G to generate fake image based on this noise! it's like a seed 
        noise = np.random.uniform(-1, 1, size=[batch, z_input_size])
        image_fake = Generator.predict_on_batch(noise)

        # sampled real_image from dataset
        rand_train_index = np.random.randint(0, X_train.shape[0], size=batch)
        image_real = X_train[rand_train_index, :]

        # concatenate real and fake images
        """
        X = [
            image_real => label : 1 (we can multiply a smoothing factor)
            image_fake => label : 0
            ]
        """
        X = np.vstack((image_real, image_fake))
        y = np.zeros(len(X))

        # putting label "1" to image_real
        y[len(image_real):] = 1*(1 - smooth)
        y = y.astype(int)

        # train only discriminator
        d_loss = Discriminator.train_on_batch(x=X, y=y)

        # NOTE: remember?? we set discriminator OFF during the training of GAN!
        # So, we can safely train only generator, weight of discriminator set fixed!
        g_loss = GAN.train_on_batch(x=noise, y=y[len(noise):])

        loss_dict['d'].append(d_loss)
        loss_dict['g'].append(g_loss)

        if e%1000 == 0:
            plt.imshow(image_fake)
            plt.show()
            plot_generated(noise, Generator)

    plot_metircs(loss_dict)
    return "done!"


Gen = Generator()
Dis = Discriminator()
GAN = Generative_Adversarial_Network(Gen, Dis)
GAN.summary()
Gen.summary()
Dis.summary()

gan_losses = {"d":[], "g":[], "f":[]}
epoch = 30000
batch = 1000
smooth = 0.9
z_input_size = 100
row, col = 28, 28

z_group_matrix = np.random.uniform(0, 1, examples*z_input_size)
z_group_matrix = z_group_matrix.reshape([9, z_input_size])
print(z_group_matrix.shape)

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train.reshape(X_train.shape[0], row*col), X_test.reshape(X_test.shape[0], row*col)
X_train.astype('float32')
X_test.astype('float32')
X_train, X_test = X_train/255, X_test/255
print('X_train shape: ', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

Training(z_input_size, Gen, Dis, GAN, loss_dict=gan_losses, X_train=X_train, epoch=epoch, batch=batch, smooth=smooth)

1 个答案:

答案 0 :(得分:0)

模型本身是正确的。

我建议进行一些小的更改:

  1. 平滑度0.9太大。使其接近0.1。
  2. 您的泄漏系数为0.2,通常是非常小的小数,接近0;走走 0.01 / 0.02。
  3. 批量大约为400
  4. 大约2000年的时代
  5. 最后,以较大的阈值尽早停止。