将return_state = True设置为已加载模型中的LSTM层

时间:2019-05-08 21:43:11

标签: tensorflow keras

我已经训练了编码器-解码器模型并将其保存到文件中,但是事实证明我忘了将解码器的LSTM层上的return_state标志设置为true。现在,我无法实现推断。可以在构造函数之外为LSTM设置return_state标志吗?

1 个答案:

答案 0 :(得分:0)

我尚未通过再次训练模型来验证以下代码,但这可能会对您有所帮助:

  1. 首先,使用tf.keras.models.load_model方法加载现有的Keras模型:

    model = tf.keras.models.load_model('models/model.h5')
    
  2. 如果LSTM层位于第二个索引处,我们将得到Layer对象:

    lstm = model.layers[2]
    
  3. lstmtf.keras.layers.LSTM()对象。我们可以修改return_state参数:

    lstm.return_state = True # Set the updated value here
    

提示:

就像return_state中的LSTM一样,我注意到我们还可以修改Keras中所有类型的图层的参数。在这里,我还尝试更改Dense层的units参数:

dense.units = 23 # Previous value was 64!