AssertionError in Keras when building a GAN model

Date: 2017-05-16 19:43:31

Tags: python machine-learning tensorflow keras

I'm trying to build a conditional GAN model based on Jacob's code (https://github.com/jacobgil/keras-dcgan). However, when the model is compiled, the generator_containing_discriminator model throws an AssertionError.

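I think it may be related to the multiple inputs and outputs of the generator and discriminator models: when I try to combine the two models, Keras doesn't like it.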

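Basically, I want the generator to take two inputs, g_inputs and auxiliary_input, and produce two outputs, T and auxiliary_input. In this case, I'm simply passing auxiliary_input through unchanged.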

I want the discriminator to take two inputs, d_inputs and auxiliary_input, and produce one output, T.
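To make this input contract concrete, here is a hypothetical helper (`make_generator_batch` is not part of the post) that builds both generator inputs for a batch of class labels. It assumes that the 110 values of g_inputs are 100 noise values followed by the 10-value one-hot condition; the post only gives the sizes 110 and 10, so that split is an assumption. The same one-hot array is what would be fed as auxiliary_input to the generator and, next to the generated image, to the discriminator.

```python
import numpy as np


def make_generator_batch(labels, noise_dim=100, num_classes=10, rng=None):
    """Hypothetical helper: build (g_inputs, auxiliary_input) arrays for a batch.

    ASSUMPTION: g_inputs (110 values) = noise_dim noise values + a one-hot
    condition of num_classes values; the original post only states the sizes.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    # auxiliary_input: one one-hot row per label, shape (batch, 10)
    one_hot = np.eye(num_classes, dtype="float32")[labels]
    # Uniform noise in [-1, 1), shape (batch, 100)
    noise = rng.uniform(-1.0, 1.0, size=(len(labels), noise_dim)).astype("float32")
    # g_inputs: noise followed by the condition, shape (batch, 110)
    g_inputs = np.concatenate([noise, one_hot], axis=1)
    return g_inputs, one_hot


g_batch, aux_batch = make_generator_batch([3, 7, 7, 0], rng=np.random.default_rng(0))
print(g_batch.shape, aux_batch.shape)
```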

I'm sure the dimensions match, so does anyone know why this still fails when I compile generator_containing_discriminator? Thanks a lot!
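The claim that the dimensions match can be checked by hand. This plain-Python sketch walks the spatial sizes through the layers of the posted code (stride-1 convolutions, 2x2 upsampling and pooling, channels-first as selected by `K.set_image_dim_ordering('th')`). The helpers `conv_out` and `pool_out` are illustrative names, not Keras functions, and the sketch only checks arithmetic: matching shapes do not rule out a problem in how the Keras graph is wired.

```python
def conv_out(size, kernel, padding):
    """Spatial size after a stride-1 convolution (illustrative helper)."""
    if padding == "same":
        return size
    if padding == "valid":
        return size - kernel + 1
    raise ValueError(padding)


def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool


# Generator: Reshape to (128, 7, 7), then UpSampling x2, Conv 'same', UpSampling x2, Conv 'same'
g = 7
g = g * 2                        # UpSampling2D(size=(2, 2)) -> 14
g = conv_out(g, 5, "same")       # Convolution2D(64, 5, 5, border_mode='same') -> 14
g = g * 2                        # UpSampling2D(size=(2, 2)) -> 28
g = conv_out(g, 5, "same")       # Convolution2D(1, 5, 5, border_mode='same') -> 28
generator_image = (1, g, g)      # channels-first: 1 filter in the last conv

# Discriminator on a (1, 28, 28) image
d = 28
d = conv_out(d, 5, "same")       # padding='same' -> 28
d = pool_out(d)                  # MaxPooling2D -> 14
d = conv_out(d, 5, "valid")      # Convolution2D(128, 5, 5) uses the default 'valid' -> 10
d = pool_out(d)                  # MaxPooling2D -> 5
flatten_size = 128 * d * d       # Flatten -> 3200
dense_input_size = flatten_size + 10  # after concatenating the 10-value auxiliary_input

print(generator_image, flatten_size, dense_input_size)
```

So the generator emits (1, 28, 28) images, matching `d_inputs`, and the first `Dense(1024)` in the discriminator sees 3210 features.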

My code:

import warnings
warnings.filterwarnings('ignore')

from keras import backend as K
from keras.layers import Dense, Input
from keras.layers import Reshape, concatenate
from keras.layers.core import Activation
from keras.models import Model
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Flatten
K.set_image_dim_ordering('th')


#g_inputs is the input for generator
#auxiliary_input is the condition
#d_inputs is the input for discriminator
g_inputs = (Input(shape=(110,), dtype='float32', name='g_inputs'))
auxiliary_input = (Input(shape=(10,), dtype='float32', name='auxiliary_input'))
d_inputs = (Input(shape=(1,28,28), dtype='float32', name='d_inputs'))


def generator_model():
    T = (Dense(1024))(g_inputs)
    T = (Dense(128*7*7))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (Reshape((128, 7, 7), input_shape=(128*7*7,)))(T)
    T = (UpSampling2D(size=(2, 2)))(T)
    T = (Convolution2D(64, 5, 5, border_mode='same'))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (UpSampling2D(size=(2, 2)))(T)
    T = (Convolution2D(1, 5, 5, border_mode='same'))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    model = Model(input=[g_inputs, auxiliary_input], output=[T,auxiliary_input])
    return model

def discriminator_model():
    T = (Convolution2D(filters= 64, kernel_size= (5,5), padding='same'))(d_inputs)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (MaxPooling2D(pool_size=(2, 2)))(T)
    T = (Convolution2D(128, 5, 5))(T)
    T = (BatchNormalization())(T)
    T = (Activation('tanh'))(T)
    T = (MaxPooling2D(pool_size=(2, 2)))(T)
    T = (Flatten())(T)
    T = concatenate([T, auxiliary_input])
    T = (Dense(1024))(T)
    T = (Activation('tanh'))(T)
    T = (Dense(1))(T)
    T = (Activation('sigmoid'))(T)
    model = Model(input=[d_inputs,auxiliary_input], output=T)
    return model

def generator_containing_discriminator(generator, discriminator):
    T1 = generator([g_inputs, auxiliary_input])
    discriminator.trainable = False
    T2 = discriminator(T1)
    model = Model(input=[g_inputs, auxiliary_input], output=T2)
    return model



discriminator = discriminator_model()
generator = generator_model()


generator_containing_discriminator(generator, discriminator).summary()
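One restructuring worth trying avoids two unusual things in the code above: routing an `Input` tensor straight through as a model output, and reusing the same module-level `Input` tensors in three different models. The following is a minimal sketch, not a confirmed diagnosis of the AssertionError. It assumes TensorFlow 2.x with its bundled `tf.keras` (Keras 2 layer names such as `Conv2D(..., padding=...)`, channels-last images). Because the posted generator never uses auxiliary_input to compute T, the generator here takes only g_inputs, and the combined model feeds the condition directly to the discriminator. The function names `build_generator`, `build_discriminator` and `build_combined` are hypothetical.

```python
import numpy as np
from tensorflow.keras import layers, Model


def build_generator():
    # Each model creates its own Input tensors instead of sharing globals
    g_in = layers.Input(shape=(110,), name="g_inputs")
    x = layers.Dense(1024)(g_in)
    x = layers.Dense(128 * 7 * 7)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.Reshape((7, 7, 128))(x)          # channels-last (assumption)
    x = layers.UpSampling2D(size=(2, 2))(x)
    x = layers.Conv2D(64, (5, 5), padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.UpSampling2D(size=(2, 2))(x)
    x = layers.Conv2D(1, (5, 5), padding="same")(x)
    img = layers.Activation("tanh")(x)          # (28, 28, 1) image
    return Model(g_in, img, name="generator")


def build_discriminator():
    img_in = layers.Input(shape=(28, 28, 1), name="d_inputs")
    aux_in = layers.Input(shape=(10,), name="auxiliary_input")
    x = layers.Conv2D(64, (5, 5), padding="same")(img_in)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(128, (5, 5))(x)           # default 'valid' padding
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.concatenate([x, aux_in])         # append the condition
    x = layers.Dense(1024, activation="tanh")(x)
    validity = layers.Dense(1, activation="sigmoid")(x)
    return Model([img_in, aux_in], validity, name="discriminator")


def build_combined(generator, discriminator):
    # Fresh Input tensors for the stacked model
    z = layers.Input(shape=(110,), name="combined_g_inputs")
    cond = layers.Input(shape=(10,), name="combined_auxiliary_input")
    img = generator(z)
    # Freeze the discriminator inside the stacked model only; in tf.keras 2.x
    # compile() snapshots trainable flags, so the discriminator compiled
    # earlier keeps training when used on its own
    discriminator.trainable = False
    validity = discriminator([img, cond])
    return Model([z, cond], validity, name="generator_containing_discriminator")


generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer="adam", loss="binary_crossentropy")
combined = build_combined(generator, discriminator)
combined.compile(optimizer="adam", loss="binary_crossentropy")
combined.summary()

# Smoke check on a dummy batch: 2 noise vectors and 2 one-hot conditions
noise = np.random.normal(size=(2, 110)).astype("float32")
labels = np.eye(10, dtype="float32")[[3, 7]]
validity = combined.predict([noise, labels], verbose=0)
print(validity.shape)
```

The design choice is that only the combined model knows about both the noise and the condition; the generator and discriminator stay independent graphs, which keeps each `Model(...)` call simple when the two are stacked.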

0 answers