NotImplementedError:图层注意在__init__中具有参数,因此必须覆盖get_config。

时间:2020-08-18 21:07:30

标签: python tensorflow save keras-layer

我已实现此链接中建议的自定义注意层: How to add attention layer to a Bi-LSTM

class attention(Layer):    

      def __init__(self, return_sequences=True):
         self.return_sequences = return_sequences
         super(attention,self).__init__()
        
      def build(self, input_shape):
        
        self.W=self.add_weight(name="att_weight", shape=(input_shape[-1],1),
                               initializer="normal")
        self.b=self.add_weight(name="att_bias", shape=(input_shape[1],1),
                               initializer="zeros")
        
        super(attention,self).build(input_shape)
        
    def call(self, x):
        
        e = K.tanh(K.dot(x,self.W)+self.b)
        a = K.softmax(e, axis=1)
        output = x*a
        
        if self.return_sequences:
            return output
        
        return K.sum(output, axis=1)

代码运行了,但是当需要保存模型时出现了这个错误。

NotImplementedError:图层注意在__init__中具有参数,因此必须覆盖get_config

一些评论建议覆盖get_config。

“此错误使您知道tensorflow无法保存模型,因为它无法加载模型。 具体来说,它将无法重新实例化自定义的Layer类。

要解决此问题,只需根据您添加的新参数覆盖其get_config方法即可。”

查看链接:NotImplementedError: Layers with arguments in `__init__` must override `get_config`

我的问题是,基于上面的自定义关注层,如何编写get_config来解决此错误?

1 个答案:

答案 0 :(得分:0)

您需要这样的配置方法:

def get_config(self):
    config = super().get_config().copy()
    config.update({
        'return_sequences': self.return_sequences 
    })
    return config

所需的所有信息都在您链接的其他post中。