我已经训练了编码器-解码器模型并将其保存到文件中,但是事实证明我忘了将解码器的LSTM层上的return_state标志设置为true。现在,我无法实现推断。可以在构造函数之外为LSTM设置return_state标志吗?
答案 0 :(得分:0)
我尚未通过再次训练模型来验证以下代码,但这可能会对您有所帮助:
首先,使用tf.keras.models.load_model
方法加载现有的Keras模型:
model = tf.keras.models.load_model('models/model.h5')
如果LSTM
层位于第二个索引处,我们将得到Layer
对象:
lstm = model.layers[2]
lstm
是tf.keras.layers.LSTM()
对象。我们可以修改return_state
参数:
lstm.return_state = True # Set the updated value here
提示:
就像return_state
中的LSTM
一样,我注意到我们还可以修改Keras中所有类型的图层的参数。在这里,我还尝试更改Dense
层的units参数:
dense.units = 23 # Previous value was 64!