Keras multiple outputs: not converging

Time: 2017-05-25 01:59:10

Tags: model keras loss convergence multipleoutputs

I am building a multi-output Keras model:

# 'main' and 'aux' must be the names of the two output layers for the dict keys to match
model1 = Model(inputs=ip, outputs=[main, aux])
model1.compile(optimizer='sgd', loss={'main': custom_loss, 'aux': 'mean_squared_error'}, metrics=['accuracy'])

model1.fit(input_data, [main_output, aux_output], epochs=epochs, batch_size=batch_size, verbose=2, shuffle=True, validation_split=0.1, callbacks=[checkpointer])

My custom_loss function:

from keras import backend as K

def custom_loss(y_true, y_pred):
    main_pred = y_pred[0]
    main_true = y_true[0]

    loss = K.mean(K.square(main_true - main_pred), axis=-1)
    return loss

But my network does not converge; every epoch reports exactly the same values:

Epoch 1/10
Epoch 00000: val_loss improved from inf to 0.39544, saving model to ./testAE/testAE_best_weights.h5
18s - loss: 0.3896 - main_loss: 0.0449 - aux_loss: 0.3446 - main_acc: 0.0441 - val_loss: 0.3954 - val_main_loss: 0.0510 - val_aux_loss: 0.3445 - val_main_acc: 0.0402
Epoch 2/10
Epoch 00001: val_loss did not improve
18s - loss: 0.3896 - main_loss: 0.0449 - aux_loss: 0.3446 - main_acc: 0.0441 - val_loss: 0.3954 - val_main_loss: 0.0510 - val_aux_loss: 0.3445 - val_main_acc: 0.0402
Epoch 3/10
Epoch 00002: val_loss did not improve
18s - loss: 0.3896 - main_loss: 0.0449 - aux_loss: 0.3446 - main_acc: 0.0441 - val_loss: 0.3954 - val_main_loss: 0.0510 - val_aux_loss: 0.3445 - val_main_acc: 0.0402
Epoch 4/10
Epoch 00003: val_loss did not improve
18s - loss: 0.3896 - main_loss: 0.0449 - aux_loss: 0.3446 - main_acc: 0.0441 - val_loss: 0.3954 - val_main_loss: 0.0510 - val_aux_loss: 0.3445 - val_main_acc: 0.0402

I only want to train the main output. The auxiliary output will be used for testing.

1 answer:

Answer 0 (score: 0)

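Based on the information provided, it isn't clear to me why your loss is not improving, but I can address part of your question. I'm also puzzled about why you want an accuracy metric when you are training with mean squared error, since accuracy is not a meaningful measure for a regression output, but I don't know the details of your model.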
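To make the concern about accuracy plus MSE concrete, here is a small NumPy sketch. It assumes that `'accuracy'` resolves to an argmax-based categorical accuracy for this output, which is what Keras 2 does for multi-dimensional outputs trained with a non-cross-entropy loss (the exact resolution is version-dependent). The helper names and arrays are hypothetical illustration values, not data from the question.

```python
import numpy as np

def mse(y_true, y_pred):
    # Per-sample mean squared error over the last axis, then averaged over the batch
    return float(np.mean(np.mean(np.square(y_true - y_pred), axis=-1)))

def categorical_accuracy(y_true, y_pred):
    # Fraction of samples whose largest component sits at the same index
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))

# Hypothetical continuous targets (two samples, two features each)
y_true = np.array([[0.50, 0.49], [0.30, 0.31]])
# Tiny error everywhere, but the argmax flips in both samples
y_pred_close = np.array([[0.49, 0.50], [0.31, 0.30]])
# Large error, yet the argmax happens to match in both samples
y_pred_far = np.array([[0.90, 0.10], [0.00, 0.90]])

print(mse(y_true, y_pred_close), categorical_accuracy(y_true, y_pred_close))
print(mse(y_true, y_pred_far), categorical_accuracy(y_true, y_pred_far))
```

The near-perfect predictions score 0 accuracy while the poor ones score 1, so for a regression target the MSE columns in the log are the ones worth watching.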

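See this question for a simple way to train only one of your outputs (and for an explanation of how outputs/labels are passed to the loss function). When compiling the model, you can pass loss_weights=[1., 0.0] to train the model on only one output; the list follows the order of the model's outputs, here [main, aux]. That way, the loss being optimized does not include the auxiliary output. It would look like this: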

model1.compile(optimizer='sgd', loss={'main': custom_loss, 'aux': 'mean_squared_error'},
               metrics=['accuracy'], loss_weights=[1., 0.0])
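Keras minimizes a weighted sum of the per-output losses (plus any regularization terms). The sketch below just redoes that arithmetic in plain Python with the main_loss and aux_loss values from epoch 1 of the question's log; it is not Keras internals, and `total_loss` is a hypothetical helper.

```python
def total_loss(losses, weights):
    # Weighted sum of per-output losses, keyed by output name
    return sum(weights[name] * value for name, value in losses.items())

# Per-output losses taken from epoch 1 of the log in the question
losses = {'main': 0.0449, 'aux': 0.3446}

# Default behaviour: every output has weight 1, so both losses are optimized
unweighted = total_loss(losses, {'main': 1.0, 'aux': 1.0})
# Equivalent of loss_weights=[1., 0.0]: only the main output is optimized
main_only = total_loss(losses, {'main': 1.0, 'aux': 0.0})

print(unweighted, main_only)
```

The unweighted sum, about 0.3895, matches the logged `loss: 0.3896` up to rounding of the two displayed components, so most of the optimized loss currently comes from `aux`. With a zero weight, `aux` contributes nothing to the total loss or to its gradients.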

Since you are only computing the mean squared error, it would be simpler to rewrite the code as:

model1.compile(optimizer='sgd', loss={'main': 'mse', 'aux': 'mean_squared_error'},
               metrics=['accuracy'], loss_weights=[1., 0.0])
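One caveat on treating the custom loss as plain MSE: in a multi-output model, Keras calls each loss in the dict with only that output's tensors, batch dimension first, so inside `custom_loss` the expression `y_pred[0]` selects the first sample of the batch rather than the `main` output. The NumPy sketch below mirrors both computations on made-up data (three samples, two features); `custom_loss_np` and `mse_np` are hypothetical stand-ins, and whether this is what stalls training in the question is not established here.

```python
import numpy as np

def custom_loss_np(y_true, y_pred):
    # Mirrors the question's custom_loss: indexing [0] keeps only the first sample
    main_pred = y_pred[0]
    main_true = y_true[0]
    return float(np.mean(np.square(main_true - main_pred), axis=-1))

def mse_np(y_true, y_pred):
    # Mirrors Keras 'mse' (per-sample mean over the last axis),
    # followed by the batch average Keras applies afterwards
    return float(np.mean(np.mean(np.square(y_pred - y_true), axis=-1)))

# Hypothetical batch in which only the first prediction is correct
y_true = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
y_pred = np.zeros((3, 2))

print(custom_loss_np(y_true, y_pred))  # only sample 0 contributes
print(mse_np(y_true, y_pred))          # every sample contributes
```

Because Keras already hands each loss only its own output, switching to `'mse'` as above amounts to dropping the `[0]` indexing, so the loss then reflects every sample in the batch.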