预训练模型冻结的keras访问层参数

时间:2018-07-19 17:29:07

标签: python machine-learning keras lstm keras-layer

我保存了多层LSTM。现在,我想加载它并微调最后一个LSTM层。如何定位该层并更改其参数?

经过训练并保存的简单模型示例:

model = Sequential()
# first layer  #neurons 
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1], 
X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')

我可以加载和重新训练它,但找不到针对特定层并冻结所有其他层的方法。

2 个答案:

答案 0 :(得分:1)

一个简单的解决方案是命名每个层,即

model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))

然后,在加载模型后,您可以遍历图层并冻结与名称条件匹配的图层:

for layer in model.layers:
    if layer.name == '2nd_lstm':
        layer.trainable = False

然后,您需要重新编译模型,以使更改生效,然后,您可以像往常一样恢复训练。

答案 1 :(得分:0)

如果您以前已经建立并保存了模型,现在想加载它并仅微调最后一个LSTM层,则需要将其他层的trainable属性设置为False。首先,使用model.summary()方法来查找层的名称(或从顶部开始从零开始计数层的索引)。例如,这是为我的一个模型产生的输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        (None, 400, 16)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 400, 32)           4128      
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                8320      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33        
=================================================================
Total params: 12,481
Trainable params: 12,481
Non-trainable params: 0
_________________________________________________________________

然后将LSTM层以外的所有层的可训练参数设置为False

方法1:

for layer in model.layers:
    if layer.name != `lstm_2`
        layer.trainable = False

方法2:

for layer in model.layers:
    layer.trainable = False

model.layers[2].trainable = True  # set lstm to be trainable

# to make sure 2 is the index of the layer
print(model.layers[2].name)    # prints 'lstm_2'

别忘了再次编译模型以应用这些更改。