如何恢复LSTM模型的零件参数

时间:2016-07-23 15:08:22

标签: tensorflow

我已经训练了LSTM模型L1用于预训练,参数为:W-lstm,b-lstm,Wy,by等。

我想用不同的Wy和by来训练一个新的LSTM模型L2,假设是Wy2和by2。

那么,如何只从L1恢复W-lstm和b-lstm,并启动Wy2和by2?然后训练。

saver.restore()?

感谢。

1 个答案:

答案 0 :(得分:-1)

创建保护程序时,可以向其传递要保存的变量列表。因此,您可以为第一个LSTM创建一个保护程序。这是我的意思的一个例子。 (请注意,这只是我在浏览器中键入的代码,因此您可能会发现拼写错误)只需在变量范围内创建LSTM并创建一个只保存该范围内变量的保护程序。

lstm = tf.nn.rnn_cell.BasicLSTMCel(1024)
with tf.variable_scope('L1') as scope:
    state = INITIAL_STATE
    for t in range(TIME_STEPS):
        output, state = lstm(input, state)
        scope.reuse_variables()

...

l1_saver = tf.train.Saver([v for v in tf.all_variables() if 'L1' in v.name])