Keras conditional GAN skips an input layer for some reason

Date: 2019-07-07 13:36:11

Tags: keras

I'm trying to build a conditional GAN with the Keras functional API. For some reason, when I combine the generator and the discriminator, one of the input layers gets skipped, which causes a lot of problems. Everything compiles correctly, but when I try to train the GAN I get an error because no value is fed for the skipped input's placeholder. The error I receive is

InvalidArgumentError: You must feed a value for placeholder tensor 'input_6' with dtype float and shape [?,1]
     [[{{node input_6}}]] [Op:StatefulPartitionedCall]

when trying to execute this code:

g_stats = gan_model.train_on_batch([noise, fake_class_labels], y)
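
For reference, the arrays going into that call are built roughly like this (a simplified sketch; the batch size of 64 is just a placeholder for illustration, not my real setting):

import numpy as np

batch_size = 64  # placeholder value for illustration

# Generator inputs: 100-dim noise and integer class labels in [0, 10)
noise = np.random.normal(0, 1, size=(batch_size, 100)).astype('float32')
fake_class_labels = np.random.randint(0, 10, size=(batch_size, 1)).astype('float32')

# Targets for the combined model: label the fakes as "real" (1) to train the generator
y = np.ones((batch_size, 1), dtype='float32')

g_stats = gan_model.train_on_batch([noise, fake_class_labels], y)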

Below are the code and summaries for the discriminator, generator, and GAN. Thanks in advance for any help!

def create_generator():

    # Branch 1: Noise
    input_x = Input(shape=(100,)) # Noise
    x = Dense(8*8*256)(input_x)
    x = Reshape(target_shape=(8, 8, 256))(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 8x8

    # Branch 2: Class label
    input_y = Input(shape=(1,)) # Class label
    y = Embedding(10, 50)(input_y)
    y = Dense(8*8*1)(y)
    y = Reshape((8, 8, 1))(y)

    merge = Concatenate()([x, y])

    x = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding='same')(merge)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 16x16

    x = Conv2DTranspose(filters=64, kernel_size=5, strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 32x32

    x = Conv2D(filters=32, kernel_size=5, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 32x32


    x = Conv2D(filters=3, kernel_size=5, padding='same')(x)
    z = Activation('sigmoid')(x) # size = 32x32

    net = Model(inputs=[input_x, input_y], outputs=z)

    return net
WARNING:tensorflow:From D:\Applications\Anaconda\envs\ml\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:642: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 16384)        1654784     input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
reshape (Reshape)               (None, 8, 8, 256)    0           dense[0][0]                      
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 50)        500         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_v1 (BatchNo (None, 8, 8, 256)    1024        reshape[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1, 64)        3264        embedding[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 8, 8, 256)    0           batch_normalization_v1[0][0]     
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 8, 8, 1)      0           dense_1[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 8, 8, 257)    0           leaky_re_lu[0][0]                
                                                                 reshape_1[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 16, 16, 128)  822528      concatenate[0][0]                
__________________________________________________________________________________________________
batch_normalization_v1_1 (Batch (None, 16, 16, 128)  512         conv2d_transpose[0][0]           
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 16, 16, 128)  0           batch_normalization_v1_1[0][0]   
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 32, 32, 64)   204864      leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_2 (Batch (None, 32, 32, 64)   256         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 32, 32, 64)   0           batch_normalization_v1_2[0][0]   
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 32)   51232       leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_3 (Batch (None, 32, 32, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 32, 32, 32)   0           batch_normalization_v1_3[0][0]   
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 3)    2403        leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
activation (Activation)         (None, 32, 32, 3)    0           conv2d_1[0][0]                   
==================================================================================================
Total params: 2,741,495
Trainable params: 2,740,535
Non-trainable params: 960
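
A quick standalone check like the following (random dummy data, purely illustrative) is how I would verify that the generator consumes both of its inputs:

import numpy as np

g_model = create_generator()

# Dummy inputs just to exercise both branches of the generator
test_noise = np.random.normal(0, 1, size=(4, 100)).astype('float32')
test_labels = np.random.randint(0, 10, size=(4, 1)).astype('float32')

fake_images = g_model.predict([test_noise, test_labels])
print(fake_images.shape)  # should be (4, 32, 32, 3)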
def create_discriminator():

    # Branch 1: Input (image)
    input_x = Input(shape=(32, 32, 3))

    # Branch 2: Class label
    input_y = Input(shape=(1,)) # Class label
    y = Embedding(10, 1024)(input_y)
    y = Dense(1024)(y)
    y = Reshape((32, 32, 1))(y)

    merge = Concatenate()([input_x, y])


    x = Conv2D(filters=32, kernel_size=5, padding='same')(merge)
    x = LeakyReLU(0.2)(x) # size = 32x32

    x = Conv2D(filters=64, kernel_size=5, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 32x32

    x = Conv2D(filters=128, kernel_size=5, strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 16x16

    x = Conv2D(filters=256, kernel_size=5, strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x) # size = 8x8

    x = Flatten()(x)
    x = Dense(1)(x)
    z = Activation('sigmoid')(x)

    net = Model(inputs=[input_x, input_y], outputs=z)

    optim = tf.keras.optimizers.Adam(lr=0.000008, decay=1e-10)
    net.compile(optimizer=optim, loss='binary_crossentropy', metrics=['accuracy'])

    return net
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 1024)      10240       input_4[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1, 1024)      1049600     embedding_1[0][0]                
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 32, 32, 1)    0           dense_2[0][0]                    
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 4)    0           input_3[0][0]                    
                                                                 reshape_2[0][0]                  
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 32, 32, 32)   3232        concatenate_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 32, 32, 32)   0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 32, 32, 64)   51264       leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_4 (Batch (None, 32, 32, 64)   256         conv2d_3[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 32, 32, 64)   0           batch_normalization_v1_4[0][0]   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 16, 16, 128)  204928      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_5 (Batch (None, 16, 16, 128)  512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 16, 16, 128)  0           batch_normalization_v1_5[0][0]   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 8, 8, 256)    819456      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_6 (Batch (None, 8, 8, 256)    1024        conv2d_5[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 8, 8, 256)    0           batch_normalization_v1_6[0][0]   
__________________________________________________________________________________________________
flatten (Flatten)               (None, 16384)        0           leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            16385       flatten[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 1)            0           dense_3[0][0]                    
==================================================================================================
Total params: 2,156,897
Trainable params: 2,156,001
Non-trainable params: 896
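
For completeness, this is roughly how I call the discriminator on its own during the real/fake step (random placeholder data here, not my actual batches):

import numpy as np

d_model = create_discriminator()

# Placeholder data only, to show the call signature with both inputs fed explicitly
images = np.random.uniform(0, 1, size=(4, 32, 32, 3)).astype('float32')
labels = np.random.randint(0, 10, size=(4, 1)).astype('float32')
targets = np.ones((4, 1), dtype='float32')

d_stats = d_model.train_on_batch([images, labels], targets)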
def create_gan(d_model, g_model):
    d_model.trainable = False
    gen_noise, gen_label = g_model.input
    gen_output = g_model.output
    gan_output = d_model([gen_output, gen_label])
    gan_model = Model([gen_noise, gen_label], gan_output)
    gan_optim = tf.keras.optimizers.Adam(lr=0.00004, decay=1e-10)
    gan_model.compile(optimizer=gan_optim, loss='binary_crossentropy', metrics=['accuracy'])
    return gan_model

d_model = create_discriminator()

g_model = create_generator()

gan_model = create_gan(d_model, g_model)
gan_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 16384)        1654784     input_7[0][0]                    
__________________________________________________________________________________________________
input_8 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 8, 8, 256)    0           dense_6[0][0]                    
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 1, 50)        500         input_8[0][0]                    
__________________________________________________________________________________________________
batch_normalization_v1_10 (Batc (None, 8, 8, 256)    1024        reshape_4[0][0]                  
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 1, 64)        3264        embedding_3[0][0]                
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, 8, 8, 256)    0           batch_normalization_v1_10[0][0]  
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 8, 8, 1)      0           dense_7[0][0]                    
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 8, 8, 257)    0           leaky_re_lu_12[0][0]             
                                                                 reshape_5[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 16, 16, 128)  822528      concatenate_3[0][0]              
__________________________________________________________________________________________________
batch_normalization_v1_11 (Batc (None, 16, 16, 128)  512         conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, 16, 16, 128)  0           batch_normalization_v1_11[0][0]  
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 32, 32, 64)   204864      leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
batch_normalization_v1_12 (Batc (None, 32, 32, 64)   256         conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, 32, 32, 64)   0           batch_normalization_v1_12[0][0]  
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 32)   51232       leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_v1_13 (Batc (None, 32, 32, 32)   128         conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, 32, 32, 32)   0           batch_normalization_v1_13[0][0]  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 3)    2403        leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 32, 32, 3)    0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
model_2 (Model)                 (None, 1)            2156897     activation_3[0][0]               
                                                                 input_8[0][0]                    
==================================================================================================
Total params: 4,898,392
Trainable params: 2,740,535
Non-trainable params: 2,157,857
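
The surrounding training loop is structured roughly like this (a simplified sketch; sample_real_batch stands in for my actual data loading):

import numpy as np

batch_size = 64  # placeholder value

for step in range(1000):
    # --- Discriminator step: real batch labelled 1, generated batch labelled 0 ---
    real_images, real_labels = sample_real_batch(batch_size)  # placeholder helper
    noise = np.random.normal(0, 1, size=(batch_size, 100)).astype('float32')
    fake_class_labels = np.random.randint(0, 10, size=(batch_size, 1)).astype('float32')
    fake_images = g_model.predict([noise, fake_class_labels])

    d_model.train_on_batch([real_images, real_labels], np.ones((batch_size, 1)))
    d_model.train_on_batch([fake_images, fake_class_labels], np.zeros((batch_size, 1)))

    # --- Generator step: push fakes through the frozen discriminator with target 1 ---
    y = np.ones((batch_size, 1), dtype='float32')
    g_stats = gan_model.train_on_batch([noise, fake_class_labels], y)  # this is the call that fails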

0 Answers:

No answers yet