我要保存的模型中有一个自定义Keras层。我想重新加载此模型。这是我用来执行此操作的代码:
self.model = load_model(path, custom_objects={'MyLayer': MyLayer, 'custom_loss_fn': custom_loss_fn})
这是我在模型中使用的自定义层:
class MyLayer(tf.keras.layers.Layer):
def __init__(self, units, **kwargs):
super(MyLayer, self).__init__(units, **kwargs)
self.W1 = tf.keras.layers.Dense(units)
self.W2 = tf.keras.layers.Dense(units)
self.V = tf.keras.layers.Dense(1)
def call(self, values):
...
def get_config(self):
config = super().get_config()
config.update({
'w1': self.W1,
'w2': self.W2,
'v': self.V,
})
return config
当我尝试加载模型时,出现以下错误:
TypeError: __init__() missing 1 required positional argument: 'units'
那是为什么?
我正在使用kera的ModelCheckpoint
回调保存模型。这可能是兼容性问题吗?
答案 0 :(得分:0)
尝试相应地修改init函数
def __init__(self, units, **kwargs):
self.units = units
...
super(MyCustomLayer, self).__init__(**kwargs)