仅在Keras中的多输出模型训练期间显示总损失

时间:2019-12-14 16:14:21

标签: python tensorflow machine-learning keras neural-network

我正在通过Keras的功能API实现自动编码器模型。我的模型是多输出的,结果是在每个输出上都评估了损失函数。在训练过程中,将这些损失的加权总和最小化:

losses = [jsd for j in range(m)]  # JSD loss function for each output
autoencoder = Model(inputs, decodes)
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
autoencoder.compile(optimizer=sgd, loss=losses, loss_weights=[1 for k in range(m)]) # each output has the same priority

然后我将模型拟合为训练数据,并根据测试数据进行评估:

history = autoencoder.fit(train_corr, train_attr_corr, epochs=50, batch_size=10, shuffle=True, verbose=2,
                          validation_data=(test_corr, test_attr_GT))

verbose=2一样,训练和验证损失将在每个时期结束时显示在控制台中。但是,由于模型是多输出的,因此将显示所有“子损失”。例如:

Epoch 1/50
 - 3s - loss: 0.3356 - dense_4_loss: 0.0647 - dense_5_loss: 0.0436 - dense_6_loss: 0.0391 - dense_7_loss: 0.0378 - dense_8_loss: 0.0250 - dense_9_loss: 0.0362 - val_loss: 0.1067 - val_dense_4_loss: 0.0101 - val_dense_5_loss: 0.0042 - val_dense_6_loss: 0.0031 - val_dense_7_loss: 0.0036 - val_dense_8_loss: 0.0041 - val_dense_9_loss: 0.0066

问题: 是否可以仅显示每个时期的总训练损失(loss)和总验证损失val_loss

修改: 在上面的示例中,我只想显示loss: 0.3356val_loss: 0.1067

1 个答案:

答案 0 :(得分:1)

无法使用Keras model.fit函数中的默认详细选项。但是,您可以使用自定义回调来实现。使用verbosity=0在fit函数中禁用详细程度。定义以下回调函数,该回调函数将重写默认回调,并在时期开始和结束时修改结果。

class PrinterCallback(tf.keras.callbacks.Callback):

    # def on_train_batch_begin(self, batch, logs=None):
    #     # Do something on begin of training batch

    def on_epoch_end(self, epoch, logs=None):
        print('EPOCH: {}, Train Loss: {}, Val Loss: {}'.format(epoch,
                                                               logs['loss'],
                                                               logs['val_loss']))

    def on_epoch_begin(self, epoch, logs=None):
        print('-'*50)
        print('STARTING EPOCH: {}'.format(epoch))

    # def on_train_batch_end(self, batch, logs=None):
    #     # Do something on end of training batch
    #

在调用model.fit时,将此回调用作callback=[PrinterCallback()]。还有其他功能也可以在此处进行操作。例如,您可以在火车开始时做些什么,等等(代码中显示的示例对)。随意修改所需值的打印方式,例如,控制小数位。

有关Keras回调的详细信息,请访问here,您还可以检查其他回调的源代码以实现自己的回调。

希望有帮助!