加载其中具有自定义Attention层的keras模型时出现Unknown Layer错误

时间:2018-08-10 22:51:21

标签: python tensorflow keras

我有一个经过keras训练的模型,其中包含自定义关注层。我在训练时首先尝试通过检查站进行保存。然后在加载时出现了一些未知的图层错误。接下来,我尝试将模型以jason格式保存,并将权重保存在h5文件中。接下来,在尝试加载模型时,它会显示相同的错误以及以下错误:

Traceback (most recent call last):
  File "/home/khaledkucse/Project/python/DeepMethodSequenceReccomender/__main__.py", line 48, in <module>
    loaded_model = model_from_json(loaded_model_json,custom_objects={"AttentionLayer":AttentionMechanism.AttentionL})
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/engine/saving.py", line 368, in model_from_json
    return deserialize(config, custom_objects=custom_objects)
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/layers/__init__.py", line 60, in deserialize
    printable_module_name='layer')
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/engine/network.py", line 1017, in from_config
    process_layer(layer_data)
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/engine/network.py", line 1003, in process_layer
    custom_objects=custom_objects)
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/layers/__init__.py", line 60, in deserialize
    printable_module_name='layer')
  File "/home/khaledkucse/.local/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 138, in deserialize_keras_object
    ': ' + class_name)
ValueError: Unknown layer: AttentionL

我在model_from_json函数中使用了自定义对象参数。这是代码:

json_file = open(config.model_file_JSON_path, 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json,custom_objects={"AttentionLayer":AttentionMechanism.AttentionL})
loaded_model.load_weights(config.model_file_path)

我的自定义层中有get_config()方法。在这里:

   def get_config(self):
    config={'step_dim':self.step_dim}
    base_config = super(AttentionL, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

现在仍然没有运气。我如何加载我的模型。请注意,我只能保存和加载权重,但在预测测试用例时会为双向LSTM产生未初始化的错误。

0 个答案:

没有答案