使用自定义图层加载模型时的Keras TypeError

时间:2020-09-18 10:21:39

标签: python tensorflow keras keras-layer

我要保存的模型中有一个自定义Keras层。我想重新加载此模型。这是我用来执行此操作的代码:

self.model = load_model(path, custom_objects={'MyLayer': MyLayer, 'custom_loss_fn': custom_loss_fn})

这是我在模型中使用的自定义层:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super(MyLayer, self).__init__(units, **kwargs)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, values):
        ...

    def get_config(self):
        config = super().get_config()
        config.update({
          'w1': self.W1,
          'w2': self.W2,
          'v': self.V,
        })
        return config

当我尝试加载模型时,出现以下错误:

TypeError: __init__() missing 1 required positional argument: 'units'

那是为什么?

更新

我正在使用kera的ModelCheckpoint回调保存模型。这可能是兼容性问题吗?

1 个答案:

答案 0 :(得分:0)

尝试相应地修改init函数

def __init__(self, units, **kwargs):    
   self.units = units
   ...
   super(MyCustomLayer, self).__init__(**kwargs)