无法加载keras训练的模型

时间:2019-03-18 09:31:50

标签: python keras deep-learning text-classification

我正在使用以下代码来训练HAN Network。 Code Link

我已经成功地训练了模型,但是当我尝试使用keras load_model加载模型时,出现以下错误- 未知层:AttentionWithContext

2 个答案:

答案 0 :(得分:0)

根据您共享的链接,模型在模型中添加了一个明确定义的图层AttentionWithContext()。当您尝试使用keras的load_model加载模型时,该方法会显示错误,因为该层不是keras内置的,要解决此问题,您可能必须在代码中再次定义该层,然后再使用load_model加载模型。在尝试加载模型之前,请尝试在提供的链接(https://www.kaggle.com/hsankesara/news-classification-using-han/notebook)中编写类AttentionWithContext(layer)。

答案 1 :(得分:0)

在AttentionWithContext.py文件中添加以下功能:

def create_custom_objects():
    instance_holder = {"instance": None}

    class ClassWrapper(AttentionWithContext):
        def __init__(self, *args, **kwargs):
            instance_holder["instance"] = self
            super(ClassWrapper, self).__init__(*args, **kwargs)

    def loss(*args):
        method = getattr(instance_holder["instance"], "loss_function")
        return method(*args)

    def accuracy(*args):
        method = getattr(instance_holder["instance"], "accuracy")
        return method(*args)
    return {"ClassWrapper": ClassWrapper ,"AttentionWithContext": ClassWrapper, "loss": loss,
            "accuracy":accuracy}

加载模型时:

from AttentionWithContext import create_custom_objects

model = keras.models.load_model(model_path, custom_objects=create_custom_objects())

model.evaluate(X_test, y_test) # or model.predict