我的网络有2个输入,1个3D灰度体积和1个2D彩色图像。最后,我有一个2D输出,形状与输入图像相同。
我的问题:我想组合2个损失函数,每个网络输入分支一个。一个损失函数应将网络的预测与2D灰度地面真实性进行比较,另一个函数应与2D颜色输入进行比较。这些基本信息存储在单独的文件中。
我不了解y_true的工作原理以及如何告诉它查看我的groundtruth文件。我认为我的问题主要是语法之一,在看了SO上的其他文章后,我感到更加困惑。
这是我想出的代码。当然不行,但是应该让我了解我的目标。
def custom_loss(groundtruth_grayscale, groundtruth_colour):
def loss(y_true, y_pred):
loss_grayscale = ssim(y_pred, groundtruth_grayscale)
loss_colour = ssim(y_pred, groundtruth_colour)
ssim_loss = loss_grayscale + loss_colour
l1_loss_grayscale = l1(y_pred, groundtruth_grayscale)
l1_loss_colour = l1(y_pred, groundtruth_colour)
l1_loss = l1_loss_grayscale + l1_loss_colour
return ssim_loss + l1_loss
return loss
# images_groundtruth_grayscale is a variable containing all groundtruth_grayscale images
model_combined.optimizer = tf.keras.optimizers.Adam(learning_rate = 0.002).minimize(custom_loss(images_groundtruth_grayscale, images_groundtruth_colour), var_list = model_combined.trainable_variables)
如果有帮助,则模型的摘要:
model_combined.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, None, None, 1 9728 input_1[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, None, 1 512 conv2d[0][0]
__________________________________________________________________________________________________
conv3d (Conv3D) (None, None, None, N 16128 input_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 6 204864 batch_normalization[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, N 512 conv3d[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256 conv2d_1[0][0]
__________________________________________________________________________________________________
conv3d_1 (Conv3D) (None, None, None, N 1024064 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 3 18464 batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, N 256 conv3d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 3 128 conv2d_2[0][0]
__________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, None, None, N 55328 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 1 4624 batch_normalization_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, N 128 conv3d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 1 64 conv2d_3[0][0]
__________________________________________________________________________________________________
conv3d_3 (Conv3D) (None, None, None, N 13840 batch_normalization_6[0][0]
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, None, None, 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, N 64 conv3d_3[0][0]
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, None, None, 0 tf_op_layer_ExpandDims[0][0]
batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv3d_4 (Conv3D) (None, None, None, N 272 tf_op_layer_add[0][0]
==================================================================================================
Total params: 1,349,232
Trainable params: 1,348,272
Non-trainable params: 960
__________________________________________________________________________________________________