Keras条件GAN训练不好

时间:2017-05-17 16:42:32

标签: python machine-learning keras

我正在尝试基于keras-dcgan(https://github.com/jacobgil/keras-dcgan)上的jacob代码构建条件GAN模型。

我假设的模型架构如下图:
the model architecture

原始论文: http://cs231n.stanford.edu/reports/2015/pdfs/jgauthie_final_report.pdf

对于生成器,我这样插入条件(在这种情况下,条件是一组独热(one-hot)向量):先将条件与噪声向量拼接,再把拼接结果输入生成器。

对于鉴别器,我在模型中部的展平(Flatten)层处,将条件与展平后的特征拼接来插入条件。

我的代码运行,但它生成一些随机图而不是特定数字。哪一步错了?我没有适当插入条件吗?

运行大约5500次迭代后的结果:
My result after running approximately 5500 iterations

代码:

import warnings
warnings.filterwarnings('ignore')

from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Input, merge
from keras.layers import Reshape, concatenate
from keras.layers.core import Activation
from keras.models import Model
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
import tensorflow as tf
from PIL import Image
import argparse
import math
# Use Theano-style channels-first tensors (N, C, H, W) throughout the model.
K.set_image_dim_ordering('th')

# `labels` lists each digit 0-9 ten times in order; dense_to_one_hot (below)
# turns it into y_prime, a flattened (100, 10) block of one-hot vectors.
labels = np.array([0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,9,9,9,9,9,9,9,9,9,9])

def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert a 1-D array of class labels to a (len, num_classes) one-hot matrix."""
    count = labels_dense.shape[0]
    one_hot = np.zeros((count, num_classes))
    # Fancy indexing: set column `label` of each row to 1 in a single pass.
    one_hot[np.arange(count), labels_dense.ravel()] = 1
    return one_hot

# y_p is a (100, 10) matrix of one-hot condition vectors built from `labels`.
# len(y_p) must equal BATCH_SIZE (100) for the concatenations in train() to
# line up; y_dim = 10 is the length of a single one-hot vector.
y_p = dense_to_one_hot(labels)
y_size = len(y_p)
y_dim = len(y_p[0])


# Module-level symbolic tensors shared by every model-building function below.
# g_inputs: 100-d noise vector fed to the generator.
# auxiliary_input: the condition, a one-hot class vector of length y_dim.
# d_inputs: a 1x28x28 image fed to the discriminator (channels-first, matching
# K.set_image_dim_ordering('th') above).
g_inputs = (Input(shape=(100,), dtype='float32'))
auxiliary_input = (Input(shape=(y_dim,), dtype='float32'))
d_inputs = (Input(shape=(1,28,28), dtype='float32'))

def generator_model():
    """Build the conditional generator.

    Concatenates the 100-d noise vector (g_inputs) with the one-hot condition
    (auxiliary_input), then upsamples 7x7 -> 14x14 -> 28x28 feature maps to
    produce a 1x28x28 image with tanh-scaled pixels.

    Uses the module-level Input tensors, so every call wires the same
    symbolic inputs.
    """
    # Condition insertion: noise and one-hot label are joined before the
    # first Dense layer, so the condition can influence the whole network.
    T = concatenate([g_inputs,auxiliary_input])
    T = (Dense(1024))(T)
    T = (Dense(128*7*7))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    # Reshape the flat vector into channels-first feature maps (128, 7, 7).
    T = (Reshape((128, 7, 7), input_shape=(128*7*7,)))(T)
    T = (UpSampling2D(size=(2, 2)))(T)
    # NOTE(review): Convolution2D(64, 5, 5, border_mode='same') is Keras 1
    # syntax, while discriminator_model below uses the Keras 2 keyword form
    # (filters=, kernel_size=, padding=); mixing the two only works on some
    # Keras versions — confirm against the installed version.
    T = (Convolution2D(64, 5, 5, border_mode='same'))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (UpSampling2D(size=(2, 2)))(T)
    T = (Convolution2D(1, 5, 5, border_mode='same'))(T)
    # NOTE(review): BatchNormalization right before the final tanh
    # re-normalizes the output image; DCGAN-style generators usually omit BN
    # on the output layer — a possible cause of poor samples.
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    model = Model(input=[g_inputs,auxiliary_input], output=T)
    return model

def discriminator_model():
    """Build the conditional discriminator.

    Two conv + tanh + max-pool stages over the 1x28x28 input (d_inputs); the
    one-hot condition (auxiliary_input) is concatenated onto the flattened
    features before the final dense layers, which end in a sigmoid
    real/fake score.
    """
    T = (Convolution2D(filters= 64, kernel_size= (5,5), padding='same'))(d_inputs)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (MaxPooling2D(pool_size=(2, 2)))(T)
    T = (Convolution2D(128, 5, 5))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (MaxPooling2D(pool_size=(2, 2)))(T)
    T = (Flatten())(T)
    # Condition insertion point: the label is only seen by the dense head,
    # not by the convolutional feature extractor.
    T = concatenate([T, auxiliary_input])
    T = (Dense(1024))(T)
    T = (Activation('tanh'))(T)
    T = (Dense(1))(T)
    T = (Activation('sigmoid'))(T)
    model = Model(input=[d_inputs,auxiliary_input], output=T)
    return model

def generator_containing_discriminator(generator, discriminator):
    """Stack generator -> discriminator for generator training.

    Marks the discriminator non-trainable BEFORE the combined model is
    compiled, so only generator weights update when this model trains.
    The condition vector is fed to both sub-models.
    """
    T1 = generator([g_inputs, auxiliary_input])
    discriminator.trainable = False
    T2 = discriminator([T1,auxiliary_input])
    model = Model(input=[g_inputs, auxiliary_input], output=T2)
    return model

def combine_images(generated_images):
    """Tile a batch of images into a single 2-D grid for saving.

    generated_images: array of shape (N, 1, H, W), channels-first.
    Returns a (rows*H, cols*W) 2-D array in the same dtype, laid out
    row-major (image i goes to row i // cols, column i % cols).

    Fix: an empty batch previously raised ZeroDivisionError (width == 0);
    it now returns an empty (0, 0) array.
    """
    num = generated_images.shape[0]
    if num == 0:
        return np.zeros((0, 0), dtype=generated_images.dtype)
    # Near-square grid: width = floor(sqrt(N)), height rounds up to fit all.
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    shape = generated_images.shape[2:]
    image = np.zeros((height * shape[0], width * shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = index // width
        j = index % width
        # Drop the channel axis: img[0] is the (H, W) grayscale plane.
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = img[0, :, :]
    return image


def train(BATCH_SIZE,y_prime):
    """Adversarial training loop for the conditional GAN.

    BATCH_SIZE: number of real (and fake) samples per discriminator step.
        Must equal len(y_prime) (100 here) for the concatenations below.
    y_prime: fixed block of one-hot condition vectors (ten copies of each
        digit 0-9) used as the condition for every generated batch.

    Saves a sample image grid every 20 batches and model weights every 10.
    """
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Add a channel axis: (N, 28, 28) -> (N, 1, 28, 28), channels-first.
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])
    discriminator = discriminator_model()
    generator = generator_model()
    discriminator_on_generator = generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    # NOTE(review): `generator` is compiled but never trained directly, so this
    # compile only enables predict(); its loss/optimizer are effectively unused.
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)
    noise = np.zeros((BATCH_SIZE, 100))
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
            # Fresh uniform noise in [-1, 1] for the discriminator step.
            for i in range(BATCH_SIZE):
                noise[i, :] = np.random.uniform(-1, 1, 100)
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            y_batch = dense_to_one_hot(y_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE])
            # Conditions for the discriminator batch: real labels first, then
            # the fixed y_prime block that conditioned the fake images.
            y_batch = np.concatenate((y_batch , y_prime))
            # NOTE(review): the generator is always conditioned on the fixed
            # y_prime (digits 0-9 repeated) rather than labels matching the
            # real batch — confirm this is the intended CGAN pairing.
            generated_images = generator.predict([noise,y_prime], verbose=0)
            if index % 20 == 0:
                image = combine_images(generated_images)
                # Undo the [-1, 1] scaling back to [0, 255] before saving.
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")
            # Real images (target 1) stacked on fake images (target 0).
            X = np.concatenate((image_batch, generated_images))
            y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch([X,y_batch], y)
            print("batch %d d_loss : %f" % (index, d_loss))
            # New noise for the generator step.
            for i in range(BATCH_SIZE):
                noise[i, :] = np.random.uniform(-1, 1, 100)
            # NOTE(review): toggling `trainable` AFTER compile() has no effect
            # on the already-compiled `discriminator` model in Keras; the
            # freeze that matters is the one applied inside
            # generator_containing_discriminator before its compile — verify
            # against the Keras version in use.
            discriminator.trainable = False
            # Train the generator to fool the discriminator: all-ones targets.
            g_loss = discriminator_on_generator.train_on_batch([noise,y_prime], [1] * BATCH_SIZE)
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)



# Kick off training; BATCH_SIZE must equal len(y_p) (= 100) so the fixed
# condition block lines up with each batch.
train(100,y_p)

1 个答案:

答案 0 :(得分:0)

这是我用Keras构建条件GAN(CGAN)的代码:https://github.com/hklchung/GAN-GenerativeAdversarialNetwork/tree/master/CGAN

在MNIST上经过5个时期后,我得到了: MNIST CGAN output

,并且在CelebA数据集上经过50个时期后: CelebA CGAN output

我的经验是,如果在20个时间段后仍看不到任何好的结果,则说明您的模型出了点问题,再对它进行训练都不会提高图像质量。