使用自定义方法保存/加载Keras模型

时间:2020-06-04 00:24:47

标签: python tensorflow keras

我正在尝试构建一个可分为两部分的NN,其中每一部分都可以独立运行。子类化keras Model类可以很好地实现这一点,如以下玩具模型所示:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        intermediate_value = self.model_part_1(inputs)
        final_output = self.model_part_2(intermediate_value)
        return final_output

    def model_part_1(self, inputs):
        x = self.dense1(inputs)
        return x

    def model_part_2(self, inputs):
        x = self.dense2(inputs) 
        return x

除了自定义方法不会通过保存/加载进行之外,所有这些都很好地工作。使用标准的model.save("saved_model_path"),然后使用tf.keras.models.load_model("saved_model")加载,加载的模型对象在运行predict时可以按预期工作,但不再具有model_part_1model_part_2方法(属性密实1和密实2已正确加载)。

在加载时添加关键字参数custom_objects={"MyModel": MyModel}不能解决问题

应该可以将方法添加到已加载的实例中,但这确实很麻烦。

1 个答案:

答案 0 :(得分:0)

我能够通过用tf.function装饰功能来解决此问题:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        intermediate_value = self._model_part_1(inputs)
        final_output = self._model_part_2(intermediate_value)
        return final_output

    def _model_part_1(self, inputs):
        x = self.dense1(inputs)
        return x

    def _model_part_2(self, inputs):
        x = self.dense2(inputs) 
        return x

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_1(self, inputs):
        """ tf.function-deocrated version of _model_part_1 """
        return self._model_part_1(inputs)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_2(self, inputs):
        """ tf.function-deocrated version of _model_part_2 """
        return self._model_part_2(inputs
        )


使用.save()方法保存并使用tf.keras.models.load_model加载后,可以使用修饰的方法。

请注意,我使用装饰器创建了新功能;这是因为在call方法中调用修饰函数会导致错误。