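I am training a cGAN in Keras that has two outputs:
1: real/fake (1/0)
2: target label (a vector of size 10; I classify the output into 10 classes)
For real/fake, I use binary cross-entropy as the loss function.
For the target label, I want to use categorical cross-entropy when the real/fake label is 1 (real), and loss = 0 when the real/fake label is 0 (fake).
So the loss function of one output depends on the ground truth of the other output. I have tried several approaches, but none of them seem to work. Can anyone help me?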
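To make the intended behaviour concrete, here is a small NumPy sketch of the quantity described above: the per-sample categorical cross-entropy multiplied by the real/fake ground truth. The function name `conditional_cce`, the `eps` clipping constant, and the 3-class toy data are assumptions made for illustration only; they are not part of my training code.

```python
import numpy as np

def conditional_cce(real_fake, y_true, y_pred, eps=1e-7):
    """Per-sample categorical cross-entropy, forced to 0 where real_fake == 0."""
    # Clip predictions so that log(0) cannot occur.
    y_pred = np.clip(y_pred, eps, 1.0)
    # Standard categorical cross-entropy for each row.
    cce = -np.sum(y_true * np.log(y_pred), axis=-1)
    # Multiply by the real/fake ground truth: 1 keeps the loss, 0 removes it.
    return real_fake * cce

# Toy batch with 3 classes for brevity: first sample real, second sample fake.
real_fake = np.array([1.0, 0.0])
y_true = np.array([[0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
y_pred = np.array([[0.2, 0.5, 0.3],
                   [0.1, 0.1, 0.8]])

losses = conditional_cce(real_fake, y_true, y_pred)
print(losses)
```

The first entry is the usual cross-entropy of the real sample, and the second entry is exactly zero because the sample is fake.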
Here is the network:
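    # Imports assumed for the snippets in this post (Keras 2 API).
    from keras.models import Model
    from keras.layers import (Input, Convolution2D, BatchNormalization, LeakyReLU,
                              Flatten, Dropout, Dense, Activation, Reshape,
                              Concatenate, UpSampling2D, Add)
    from keras.optimizers import Adam

    def Discriminator():
        input_shape = (96, 96, 3)
        inputs = Input(shape=input_shape)
        x1 = Convolution2D(32, (3, 3), strides=(1, 1), padding='same')(inputs)
        x1 = BatchNormalization()(x1)
        x1 = LeakyReLU(0.2)(x1)
        x1 = Convolution2D(32, (3, 3), strides=(2, 2), padding='same')(x1)
        x1 = BatchNormalization()(x1)
        ### More Conv, BN and Activation layers
        x1 = Flatten()(x1)
        x1 = Dropout(0.5)(x1)
        # Output 1: real/fake probability
        x2 = Dense(1)(x1)
        x2 = Activation('sigmoid', name="real_fake")(x2)
        # Output 2: distribution over the 10 classes
        x3 = Dense(10)(x1)
        x3 = Activation('softmax', name="class_labels")(x3)
        mdl = Model(inputs=inputs, outputs=[x2, x3])
        return mdl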
    def Generator():
        inputs = Input(shape=(7, 7, 2048))
        img = Input(shape=(96, 96, 3))
        labels = Input(shape=(1, 10))
        t = Dense(256, activation='relu')(labels)   # (1, 10) -> (1, 256)
        t1 = Reshape((16, 16, 1))(t)                # (1, 256) -> (16, 16, 1)
        t1 = Convolution2D(8, (3, 3), strides=(2, 2), padding='valid')(t1)  # -> (7, 7, 8)
        x = Concatenate(axis=3)([inputs, t1])       # -> (7, 7, 2056)
        x = UpSampling2D((2, 2))(x)
        x = Convolution2D(512, kernel_size=(3, 3), padding='valid')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
        ## More UpSam, Conv, BN and Activation
        x = Activation('tanh')(x)
        x = Add()([x, img])
        model = Model(inputs=[inputs, img, labels], outputs=x)
        return model
    def GAN(generator, discriminator):
        inputs = Input(shape=(7, 7, 2048))
        img = Input(shape=(96, 96, 3))
        labels = Input(shape=(1, 10))
        x = generator([inputs, img, labels])
        # Freeze the discriminator inside the combined model
        discriminator.trainable = False
        y1, y2 = discriminator(x)
        mdl = Model([inputs, img, labels], [y1, y2])
        return mdl
    discriminator = Discriminator()
    generator = Generator()
    gan = GAN(generator, discriminator)

    g_optimizer = Adam(lr=2e-4, beta_1=0.5)
    generator.compile(loss='binary_crossentropy', optimizer=g_optimizer)
    gan.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=g_optimizer)

    d_optimizer = Adam(lr=1e-5, beta_1=0.1)
    discriminator.trainable = True
    discriminator.compile(loss=['binary_crossentropy', CUSTOM_LOSS], optimizer=d_optimizer)
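Basically, while the discriminator is trained on real vs. fake, I want it to be trained on the target class labels only when the input is a real image, not a fake produced by the generator. So I want to write a CUSTOM_LOSS that equals the categorical cross-entropy (cce) when the Discriminator receives a real image and is zero when it receives a fake one.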
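One direction that might work, given that a Keras loss function only receives the `y_true` and `y_pred` of its own output, is to encode the real/fake information in the class target itself: real images get a one-hot target and fake images get an all-zero target. The sketch below assumes that encoding and uses plain TensorFlow ops; the name `masked_categorical_crossentropy` and the toy tensors are hypothetical and only serve as illustration, not as a tested solution for the model above.

```python
import tensorflow as tf

def masked_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy that is zero for rows whose target is all zeros.

    Assumption (not part of the original code): fake images are given an
    all-zero class target, real images a one-hot class target.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    # 1.0 where the row carries a class label (real), 0.0 where it is all zeros (fake).
    mask = tf.cast(tf.reduce_sum(y_true, axis=-1) > 0, y_pred.dtype)
    # Clip to avoid log(0) on saturated softmax outputs.
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
    cce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    return mask * cce  # per-sample loss, shape (batch,)

# Hypothetical mixed batch with 3 classes: first row real (class 1), second row fake.
real_fake = tf.constant([[1.0], [0.0]])
one_hot = tf.constant([[0.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0]])
class_targets = one_hot * real_fake  # zero out the class labels of fake rows
probs = tf.constant([[0.2, 0.5, 0.3],
                     [0.1, 0.1, 0.8]])

per_sample = masked_categorical_crossentropy(class_targets, probs)
print(per_sample.numpy())
```

Under that assumption, `masked_categorical_crossentropy` could be passed in place of `CUSTOM_LOSS`, as long as the class targets fed to the discriminator for generated images are all zeros. Keras would still average the per-sample values over the whole batch, so fake rows contribute zeros to that mean.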
Thanks.