当模型具有Attention层时,无法从Model.get_config()从keras中加载模型

时间:2019-09-22 14:29:15

标签: python tensorflow keras

我正在从某个函数创建的配置字典中在keras中加载模型。我已经在很多其他模型上尝试了该方案,但都没有问题,但这是我使用tensorflow.keras.layers.Attention的第一个方案,并且在读取配置时遇到了Unknown Layer异常。

我知道有一个使用JSON / YAML序列化和加载自定义层的API,但这是一个keras层,我做错了吗?

使用Tensorflow 1.14.0的方式

from tensorflow.keras import layers, models, utils

def my_model(max_len, vocab_size, embedding_dims):
    sequence = layers.Input(shape=(max_len,), name='sequence')
    feature = layers.Input(shape=(1,), name='another_feature')

    x = layers.Embedding(input_dim=vocab_size,
                         output_dim=embedding_dims,
                         input_length=max_len)(sequence)
    out, sh, sc = layers.LSTM(64, return_state=True)(x)
    att = layers.Attention()([out, sh])
    x = layers.concatenate([att, feature])
    model = models.Model(inputs=[sequence, feature], outputs=[x])
    model.summary()
    return model.get_config()

title_max_len = 50
vocab_size = 35000
embedding_dims = 30

config = my_model(
    title_max_len, 
    vocab_size, 
    embedding_dims
)
model = models.Model.from_config(config)  # Unknown layer: Attention
utils.plot_model(
    model, 
    show_shapes=True, 
    show_layer_names=True, 
    to_file='model.png'
)

1 个答案:

答案 0 :(得分:1)

Tensorflow 1.14似乎存在问题,但在1.15中它没有任何问题。如果您能够升级,那可能是最简单的解决方案。