我正在尝试使用keras功能API创建条件GAN。出于某种原因,当我将生成器和鉴别器组合在一起时,输入层之一会被跳过,这会带来很多问题。一切都可以正确编译,但是当我尝试在GAN上进行训练时,由于没有为跳过的输入提供占位符,因此出现了错误。我收到的错误是
InvalidArgumentError: You must feed a value for placeholder tensor 'input_6' with dtype float and shape [?,1]
[[{{node input_6}}]] [Op:StatefulPartitionedCall]
尝试执行此代码时:
g_stats = gan_model.train_on_batch([noise, fake_class_labels], y)
以下是dicrim,gen和gan的代码和摘要,在此先感谢您的帮助!:
def create_generator():
# Branch 1: Noise
input_x = Input(shape=(100,)) # Noise
x = Dense(8*8*256)(input_x)
x = Reshape(target_shape=(8, 8, 256))(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 8x8
# Branch 2: Class label
input_y = Input(shape=(1,)) # Class label
y = Embedding(10, 50)(input_y)
y = Dense(8*8*1)(y)
y = Reshape((8, 8, 1))(y)
merge = Concatenate()([x, y])
x = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding='same')(merge)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 16x16
x = Conv2DTranspose(filters=64, kernel_size=5, strides=(2, 2), padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 32x32
x = Conv2D(filters=32, kernel_size=5, padding='same')(x )
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 32x32
x = Conv2D(filters=3, kernel_size=5, padding='same')(x)
z = Activation('sigmoid')(x) # size = 32x32
net = Model(inputs=[input_x, input_y], outputs=z)
return net
ARNING:tensorflow:From D:\Applications\Anaconda\envs\ml\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:642: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 100) 0
__________________________________________________________________________________________________
dense (Dense) (None, 16384) 1654784 input_1[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
reshape (Reshape) (None, 8, 8, 256) 0 dense[0][0]
__________________________________________________________________________________________________
embedding (Embedding) (None, 1, 50) 500 input_2[0][0]
__________________________________________________________________________________________________
batch_normalization_v1 (BatchNo (None, 8, 8, 256) 1024 reshape[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 1, 64) 3264 embedding[0][0]
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 8, 8, 256) 0 batch_normalization_v1[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 8, 8, 1) 0 dense_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 8, 8, 257) 0 leaky_re_lu[0][0]
reshape_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 16, 16, 128) 822528 concatenate[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_1 (Batch (None, 16, 16, 128) 512 conv2d_transpose[0][0]
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 128) 0 batch_normalization_v1_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 32, 32, 64) 204864 leaky_re_lu_1[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_2 (Batch (None, 32, 32, 64) 256 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 32, 32, 64) 0 batch_normalization_v1_2[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 32, 32, 32) 51232 leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_3 (Batch (None, 32, 32, 32) 128 conv2d[0][0]
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 32, 32, 32) 0 batch_normalization_v1_3[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 32, 32, 3) 2403 leaky_re_lu_3[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 32, 32, 3) 0 conv2d_1[0][0]
==================================================================================================
Total params: 2,741,495
Trainable params: 2,740,535
Non-trainable params: 960
def create_discriminator():
# Branch 1: Input (image)
input_x = Input(shape=(32, 32, 3))
# Branch 2: Class label
input_y = Input(shape=(1,)) # Class label
y = Embedding(10, 1024)(input_y)
y = Dense(1024)(y)
y = Reshape((32, 32, 1))(y)
merge = Concatenate()([input_x, y])
x = Conv2D(filters=32, kernel_size=5, padding='same')(merge)
x = LeakyReLU(0.2)(x) # size = 32x32
x = Conv2D(filters=64, kernel_size=5, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 32x32
x = Conv2D(filters=128, kernel_size=5, strides=(2, 2), padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 16x16
x = Conv2D(filters=256, kernel_size=5, strides=(2, 2), padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x) # size = 8x8
x = Flatten()(x)
x = Dense(1)(x)
z = Activation('sigmoid')(x)
net = Model(inputs=[input_x, input_y], outputs=z)
optim = tf.keras.optimizers.Adam(lr=0.000008, decay=1e-10)
net.compile(optimizer=optim, loss='binary_crossentropy', metrics=['accuracy'])
return net
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
embedding_1 (Embedding) (None, 1, 1024) 10240 input_4[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1, 1024) 1049600 embedding_1[0][0]
__________________________________________________________________________________________________
input_3 (InputLayer) (None, 32, 32, 3) 0
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 32, 32, 1) 0 dense_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 32, 32, 4) 0 input_3[0][0]
reshape_2[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 32, 32, 32) 3232 concatenate_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 32, 32, 32) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 32, 32, 64) 51264 leaky_re_lu_4[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_4 (Batch (None, 32, 32, 64) 256 conv2d_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 32, 32, 64) 0 batch_normalization_v1_4[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 16, 16, 128) 204928 leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_5 (Batch (None, 16, 16, 128) 512 conv2d_4[0][0]
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 16, 16, 128) 0 batch_normalization_v1_5[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 8, 8, 256) 819456 leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_6 (Batch (None, 8, 8, 256) 1024 conv2d_5[0][0]
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 8, 8, 256) 0 batch_normalization_v1_6[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 16384) 0 leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 1) 16385 flatten[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 1) 0 dense_3[0][0]
==================================================================================================
Total params: 2,156,897
Trainable params: 2,156,001
Non-trainable params: 896
def create_gan(d_model, g_model):
d_model.trainable = False
gen_noise, gen_label = g_model.input
gen_output = g_model.output
gan_output = d_model([gen_output, gen_label])
gan_model = Model([gen_noise, gen_label], gan_output)
gan_optim = tf.keras.optimizers.Adam(lr=0.00004, decay=1e-10)
gan_model.compile(optimizer=gan_optim, loss='binary_crossentropy', metrics=['accuracy'])
return gan_model
d_model = create_discriminator()
g_model = create_generator()
gan_model = create_gan(d_model, g_model)
gan_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_7 (InputLayer) (None, 100) 0
__________________________________________________________________________________________________
dense_6 (Dense) (None, 16384) 1654784 input_7[0][0]
__________________________________________________________________________________________________
input_8 (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
reshape_4 (Reshape) (None, 8, 8, 256) 0 dense_6[0][0]
__________________________________________________________________________________________________
embedding_3 (Embedding) (None, 1, 50) 500 input_8[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_10 (Batc (None, 8, 8, 256) 1024 reshape_4[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 1, 64) 3264 embedding_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 8, 8, 256) 0 batch_normalization_v1_10[0][0]
__________________________________________________________________________________________________
reshape_5 (Reshape) (None, 8, 8, 1) 0 dense_7[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 8, 8, 257) 0 leaky_re_lu_12[0][0]
reshape_5[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 16, 16, 128) 822528 concatenate_3[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_11 (Batc (None, 16, 16, 128) 512 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 16, 16, 128) 0 batch_normalization_v1_11[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 32, 32, 64) 204864 leaky_re_lu_13[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_12 (Batc (None, 32, 32, 64) 256 conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 32, 32, 64) 0 batch_normalization_v1_12[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 32, 32, 32) 51232 leaky_re_lu_14[0][0]
__________________________________________________________________________________________________
batch_normalization_v1_13 (Batc (None, 32, 32, 32) 128 conv2d_10[0][0]
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU) (None, 32, 32, 32) 0 batch_normalization_v1_13[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 3) 2403 leaky_re_lu_15[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 32, 32, 3) 0 conv2d_11[0][0]
__________________________________________________________________________________________________
model_2 (Model) (None, 1) 2156897 activation_3[0][0]
input_8[0][0]
==================================================================================================
Total params: 4,898,392
Trainable params: 2,740,535
Non-trainable params: 2,157,857