如何在Tensorflow中还原变量并在为lstm创建新图时使用它们?

时间:2019-01-09 22:57:54

标签: tensorflow lstm

我正在预训练其中包含一些lstm层的图。我需要保存lstm变量,并在另一个结构完全不同的图中使用它。经过研究,我发现可以使用tf.train.Saver仅保存lstm单元权重。现在我的问题是,如何使用saver.restore还原这些权重,更重要的是,如何在新的lstm图中使用它们?

这是我所做的部分的代码:

def q_network(X_state, name):
with tf.variable_scope(name) as scope:
    #defining the network
    with tf.variable_scope('lstm') as vs:
        input=tf.unstack(X_state ,time_steps,1)
        lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)
        out_lstm,_=rnn.static_rnn(lstm_layer,input,dtype="float32")
        out = out_lstm[-1]
    trainable_vars_lstm_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)
    fc1 = tf.layers.dense(out, 80, activation = tf.nn.relu,kernel_initializer=initializer)
    fc2 = tf.layers.dense(fc1, 60, activation = tf.nn.relu,kernel_initializer=initializer)
    fc3 = tf.layers.dense(fc2, 40,kernel_initializer=initializer, activation = tf.nn.relu)
    logits_ = tf.layers.dense(fc3, n_outputs,kernel_initializer=initializer)
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope.name)
    trainable_vars_by_name = {var.name[len(scope.name):]: var for var in trainable_vars}
    variables_ = tf.trainable_variables()
return logits_, trainable_vars_by_name, variables_, trainable_vars_lstm_

X_state = tf.placeholder(tf.float32, shape=[None, time_steps ,state_size])
X_action = tf.placeholder(tf.int32, shape=[None])
logits, online_vars, variables, trainable_vars_lstm = q_network(X_state, name="q_networks/online")


init = tf.global_variables_initializer()
saver = tf.train.Saver(trainable_vars_lstm)

0 个答案:

没有答案