Keras multiple outputs: one loss depends on the other

Time: 2018-08-10 19:15:19

Tags: keras

I am training a cGAN in Keras that has two outputs:

1: real/fake (1/0)

2: target label (a vector of size 10; I am classifying the output into 10 classes)

For real/fake I use binary cross-entropy as the loss function.

For the target label, I want to use categorical cross-entropy when the real/fake label is 1 (real), and loss = 0 when the real/fake label is 0 (fake).

So the loss function of one output depends on the ground truth of the other output. I have been trying different approaches, but none of them seem to work. Can someone help me?

Here is the network:

from keras.models import Model
from keras.layers import (Input, Dense, Activation, Flatten, Dropout, Reshape,
                          Convolution2D, UpSampling2D, Concatenate, Add,
                          BatchNormalization, LeakyReLU)
from keras.optimizers import Adam

def Discriminator():
  input_shape = (96,96,3)
  inputs = Input(shape=input_shape)

  x1 = Convolution2D(32,(3,3), strides=(1,1), padding='same')(inputs)
  x1 = BatchNormalization()(x1)
  x1 = LeakyReLU(0.2)(x1)
  x1 = Convolution2D(32,(3,3), strides=(2,2), padding='same')(x1)
  x1 = BatchNormalization()(x1)
### More Conv,BN and Activation layers     
  x1 = Flatten()(x1)
  x1 = Dropout(0.5)(x1)
  x2 = Dense(1)(x1)
  x2 = Activation('sigmoid', name="real_fake")(x2)
  x3 = Dense(10)(x1)
  x3 = Activation('softmax', name="class_labels")(x3)
  mdl = Model(inputs=inputs, outputs=[x2,x3])
  return mdl

def Generator():
  inputs = Input(shape=(7,7,2048))
  img = Input(shape=(96,96,3))
  labels = Input(shape=(1,10))

  t = Dense(256, activation = 'relu')(labels)
  t1 = Reshape((16,16,1))(t)
  t1 = Convolution2D(8,(3,3), strides=(2,2), padding='valid')(t1)     

  x = Concatenate(axis=3)([inputs,t1])
  x = UpSampling2D((2, 2))(x)
  x = Convolution2D(512, kernel_size=(3, 3), padding='valid')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
## More UpSam, Conv, BN and Activation
  x = Activation('tanh')(x)
  x = Add()([x,img])
  model = Model(inputs=[inputs,img,labels] , outputs=x)
  return model

def GAN(generator, discriminator):
  inputs = Input(shape=(7,7,2048))
  img = Input(shape=(96,96,3))
  labels = Input(shape=(1,10))
  x = generator([inputs, img, labels])
  discriminator.trainable = False
  y1,y2 = discriminator(x)
  mdl = Model([inputs,img,labels], [y1,y2])
  return mdl

discriminator = Discriminator()
generator = Generator()
gan = GAN(generator, discriminator)
g_optimizer = Adam(lr=2e-4, beta_1=0.5)
generator.compile(loss='binary_crossentropy', optimizer=g_optimizer)
gan.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=g_optimizer)
d_optimizer = Adam(lr=1e-5, beta_1=0.1)
discriminator.trainable = True
discriminator.compile(loss=['binary_crossentropy', CUSTOM_LOSS], optimizer=d_optimizer)

Basically, when training the discriminator on real vs. fake, I want it to also train on the target class labels only when the input is a real image and not a fake produced by the generator. So I want to write a CUSTOM_LOSS that is categorical cross-entropy when the discriminator gets a real image and zero when it gets a fake one.
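For reference, here is a minimal sketch of one way such a CUSTOM_LOSS could be written. It assumes that fake samples are given an all-zero vector as the class_labels target, so the real/fake information can be recovered from y_true itself; the name masked_categorical_crossentropy is just an illustrative placeholder, not an existing Keras loss.

from keras import backend as K

def masked_categorical_crossentropy(y_true, y_pred):
  # y_true is the one-hot class target for real images and an all-zero
  # vector for fakes, so summing over the class axis yields 1 or 0.
  mask = K.sum(y_true, axis=-1)
  # Standard categorical cross-entropy, zeroed out for fake samples.
  return mask * K.categorical_crossentropy(y_true, y_pred)

This function would then be passed as CUSTOM_LOSS in discriminator.compile above. An alternative that avoids a custom loss entirely is to keep plain 'categorical_crossentropy' and pass a per-sample sample_weight (1 for real batches, 0 for fake batches) for the class_labels output when calling train_on_batch.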

Thanks

0 Answers:

There are no answers