How to correctly load a model with a custom layer?

Date: 2020-06-16 22:20:55

Tags: python tensorflow keras keras-layer tf.keras

I am trying to load a model that contains a custom layer:

siamese_model = load_model(path, custom_objects={'siamese_loss': SIAMESE_LOSS})

Since I pass the custom_objects dictionary, I expected the model to load successfully, but the error still appears:

ValueError: Unknown layer: SIAMESE_LOSS
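As background, a minimal sketch assuming TF 2.x `tf.keras`: a model's config stores each layer's class name under `class_name`, and `custom_objects` is looked up by that exact, case-sensitive name. The layer below is a hypothetical stand-in that only reuses the class name `SIAMESE_LOSS`; it does not reproduce the asker's loss logic.

```python
import tensorflow as tf


class SIAMESE_LOSS(tf.keras.layers.Layer):
    """Hypothetical stand-in that only shares the asker's class name."""

    def call(self, inputs):
        # Trivial computation; the point here is the class name, not the math.
        return inputs * 2.0


model_input = tf.keras.Input(shape=(4,))
model_output = SIAMESE_LOSS()(model_input)
model = tf.keras.Model(model_input, model_output)

# The serialized config records the Python class name of every layer.
config = model.get_config()
layer_class_names = [layer["class_name"] for layer in config["layers"]]
print(layer_class_names)

# custom_objects keys must match that recorded name exactly (case-sensitive).
restored = tf.keras.Model.from_config(
    config, custom_objects={"SIAMESE_LOSS": SIAMESE_LOSS}
)
print(type(restored.layers[-1]).__name__)
```

With the key spelled `'siamese_loss'`, as in the `load_model` call above, the lookup for `SIAMESE_LOSS` has nothing to match, which is consistent with the `Unknown layer: SIAMESE_LOSS` message.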

The custom layer's code:

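from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.losses import mae, mse

# mmd(source, target) is the asker's own MMD helper; its code is not shown.


class SIAMESE_LOSS(Layer):
    """Adds the training loss and metrics internally and passes both predictions through."""

    def __init__(self, **kwargs):
        super(SIAMESE_LOSS, self).__init__(**kwargs)

    @staticmethod
    def mmd_loss(source_samples, target_samples):
        # Maximum mean discrepancy between source and target feature samples
        return mmd(source_samples, target_samples)

    @staticmethod
    def regression_loss(pred, labels):
        # Mean absolute error, averaged over the batch
        return K.mean(mae(pred, labels))

    @staticmethod
    def regression_mse(pred, labels):
        # Mean squared error, averaged over the batch
        return K.mean(mse(pred, labels))

    def call(self, inputs, **kwargs):
        # inputs = [source_labels, target_labels, source_pred, target_pred,
        #           source_samples, target_samples]
        source_labels = inputs[0]
        target_labels = inputs[1]
        source_pred = inputs[2]
        target_pred = inputs[3]
        source_samples = inputs[4]
        target_samples = inputs[5]

        source_loss = self.regression_loss(source_pred, source_labels)
        target_loss = self.regression_loss(target_pred, target_labels)
        mmd_loss = self.mmd_loss(source_samples, target_samples)
        total_loss = source_loss + target_loss + mmd_loss

        source_mse = self.regression_mse(source_pred, source_labels)
        target_mse = self.regression_mse(target_pred, target_labels)

        # Register the combined loss and the individual terms as metrics
        self.add_loss(total_loss, inputs=True)
        self.add_metric(target_loss, aggregation='mean', name='target_mae')
        self.add_metric(source_loss, aggregation='mean', name='source_mae')
        self.add_metric(mmd_loss, aggregation='mean', name='MMD')
        self.add_metric(target_mse, aggregation='mean', name='target_mse')
        self.add_metric(source_mse, aggregation='mean', name='source_mse')
        return source_pred, target_pred

    def get_config(self):
        # get_config takes no arguments and must return the config dict.
        # This layer has no constructor arguments of its own, so the base config suffices.
        config = super(SIAMESE_LOSS, self).get_config()
        return config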
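For comparison, a minimal sketch assuming TF 2.x `tf.keras`: `get_config()` must return a dict, and any extra constructor argument has to be added to it so that `from_config()` can rebuild the layer. `WeightedSum` and its `mmd_weight` argument are hypothetical and not part of the asker's code.

```python
import tensorflow as tf


class WeightedSum(tf.keras.layers.Layer):
    """Hypothetical layer with one extra constructor argument."""

    def __init__(self, mmd_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.mmd_weight = mmd_weight

    def call(self, inputs):
        # Combine two loss-like tensors with a configurable weight.
        regression_loss, mmd_loss = inputs
        return regression_loss + self.mmd_weight * mmd_loss

    def get_config(self):
        # Start from the base config (name, trainable, dtype) and add our own argument.
        config = super().get_config()
        config.update({"mmd_weight": self.mmd_weight})
        return config  # must return the dict, not None


layer = WeightedSum(mmd_weight=0.5, name="weighted_sum")
config = layer.get_config()

# from_config passes the dict back to __init__ as keyword arguments.
clone = WeightedSum.from_config(config)
result = clone([tf.constant(1.0), tf.constant(4.0)])
print(clone.name, clone.mmd_weight, float(result))
```

Because `SIAMESE_LOSS.__init__` accepts only `**kwargs`, the base-class config should be enough for it; its override only needs to return that dict.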

One thing that may matter: I did not override the get_config() method when I trained the model. Could that be what is causing the problem?
