我一直在尝试将一些改变的权重加载到Tensorflow计算图中,同时训练多层LSTM RNN。使用以下代码行:
variables_names =[v.name for v in tf.trainable_variables()]
values = session.run(variables_names)
给出了我使用的变量的名称和值,LSTMCell的权重名称是
rnn/multi_rnn_cell/cell_0/lstm_cell/weights:0
rnn/multi_rnn_cell/cell_1/lstm_cell/weights:0
等等,但我不能直接在
中使用上述名称rnn/multi_rnn_cell/cell_0/lstm_cell/weights.load(values[0], session)
将值加载回来的方法。有谁知道如何将新的重量加载到LSTM细胞中?
答案 0 :(得分:3)
将变量名称转换为tf.Variable
对象的最简单方法是过滤tf.trainable_variables()
,匹配名称:
cell_0_weights = [v for v in tf.trainable_variables()
if v.name == 'rnn/multi_rnn_cell/cell_0/lstm_cell/weights:0'][0]
(这不是特别有效,但变量集通常很小,效率低下并不重要。)
拥有tf.Variable
对象后,您可以使用其load()
方法分配新权重:
cell_0_weights.load(values[0], sess)