如何找到LSTMCell权重和偏差名称以将值加载到它们中?

时间:2017-03-28 16:13:32

标签: tensorflow

我一直在尝试将一些改变的权重加载到Tensorflow计算图中,同时训练多层LSTM RNN。使用以下代码行:

variables_names =[v.name for v in tf.trainable_variables()]
values = session.run(variables_names)

给出了我使用的变量的名称和值,LSTMCell的权重名称是

rnn/multi_rnn_cell/cell_0/lstm_cell/weights:0
rnn/multi_rnn_cell/cell_1/lstm_cell/weights:0

等等,但我不能直接在

中使用上述名称
rnn/multi_rnn_cell/cell_0/lstm_cell/weights.load(values[0], session)

将值加载回来的方法。有谁知道如何将新的重量加载到LSTM细胞中?

1 个答案:

答案 0 :(得分:3)

将变量名称转换为tf.Variable对象的最简单方法是过滤tf.trainable_variables(),匹配名称:

cell_0_weights = [v for v in tf.trainable_variables()
                  if v.name == 'rnn/multi_rnn_cell/cell_0/lstm_cell/weights:0'][0]

(这不是特别有效,但变量集通常很小,效率低下并不重要。)

拥有tf.Variable对象后,您可以使用其load()方法分配新权重:

cell_0_weights.load(values[0], sess)