我正在预训练一个其中包含一些LSTM层的图。我需要保存这些LSTM变量，并在另一个结构完全不同的图中使用它们。经过研究，我发现可以使用tf.train.Saver仅保存LSTM单元权重。现在我的问题是，如何使用saver.restore还原这些权重，更重要的是，如何在新的LSTM图中使用它们？
这是我所做的部分的代码:
def q_network(X_state, name):
    """Build a Q-network: an LSTM feature extractor followed by a dense head.

    Args:
        X_state: float32 tensor of shape (batch, time_steps, state_size);
            unstacked along axis 1 into `time_steps` step tensors.
        name: variable-scope name under which all variables are created.

    Returns:
        A 4-tuple:
            logits_: (batch, n_outputs) output of the final dense layer.
            trainable_vars_by_name: dict mapping each trainable variable's
                name (with the scope prefix stripped) to the variable.
            variables_: all trainable variables in the graph.
            trainable_vars_lstm_: trainable variables created inside the
                inner 'lstm' scope only (useful for a scope-restricted Saver).
    """
    with tf.variable_scope(name) as scope:
        # The LSTM lives in its own sub-scope so its variables can be
        # collected -- and saved/restored -- independently of the dense head.
        with tf.variable_scope('lstm') as vs:
            # Renamed from `input` to avoid shadowing the builtin.
            inputs = tf.unstack(X_state, time_steps, 1)
            lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1)
            out_lstm, _ = rnn.static_rnn(lstm_layer, inputs, dtype="float32")
            # Keep only the last time step's output as the sequence feature.
            out = out_lstm[-1]
            # Only the variables created under the 'lstm' scope.
            trainable_vars_lstm_ = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

        # Dense head; `initializer` is a module-level kernel initializer
        # defined elsewhere in the file.
        fc1 = tf.layers.dense(out, 80, activation=tf.nn.relu,
                              kernel_initializer=initializer)
        fc2 = tf.layers.dense(fc1, 60, activation=tf.nn.relu,
                              kernel_initializer=initializer)
        fc3 = tf.layers.dense(fc2, 40, activation=tf.nn.relu,
                              kernel_initializer=initializer)
        logits_ = tf.layers.dense(fc3, n_outputs,
                                  kernel_initializer=initializer)

        # All trainable variables under this network's scope, keyed by name
        # with the scope prefix stripped (handy for cross-graph restore).
        trainable_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope.name)
        trainable_vars_by_name = {var.name[len(scope.name):]: var
                                  for var in trainable_vars}
        variables_ = tf.trainable_variables()
    return logits_, trainable_vars_by_name, variables_, trainable_vars_lstm_
# Placeholder for batched state sequences: (batch, time_steps, state_size).
X_state = tf.placeholder(tf.float32, shape=[None, time_steps ,state_size])
# One int32 per batch element -- presumably the chosen action index; verify
# against the (unseen) training code.
X_action = tf.placeholder(tf.int32, shape=[None])
# Build the online Q-network; trainable_vars_lstm holds only the variables
# created inside the network's inner 'lstm' scope.
logits, online_vars, variables, trainable_vars_lstm = q_network(X_state, name="q_networks/online")
init = tf.global_variables_initializer()
# Saver restricted to the LSTM-scope variables, so only the LSTM cell
# weights are written to / restored from the checkpoint.
saver = tf.train.Saver(trainable_vars_lstm)