TensorFlow: BasicLSTM and dynamicRNN - trying to reuse a model

Asked: 2017-06-27 05:18:37

Tags: python-2.7 tensorflow lstm recurrent-neural-network

I am trying to reuse a trained model (an LSTM with an RNN) on a different dataset. When I use the code below, I get a ValueError:

ValueError: Variable LSTM/rnn/basic_lstm_cell/weights does not exist, or was not created with tf.get_variable()

Can anyone tell me what is going wrong here? Thanks!

In the main file, the code is as follows:
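For context, this message is what `tf.get_variable` raises when a variable scope is opened with `reuse=True` and the requested name is not in the current graph's variable store. A minimal sketch of that behaviour follows. It is not the asker's code: the variable name `weights` and its shape are chosen only for illustration, and it assumes a TensorFlow installation that provides `tensorflow.compat.v1`.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TF1-style graph mode, as in the question

error_message = None
with tf.Graph().as_default():
    try:
        # reuse=True means "look up an existing variable, never create one".
        # Nothing has been created with tf.get_variable in this fresh graph,
        # so the lookup fails.
        with tf.variable_scope("LSTM", reuse=True):
            tf.get_variable("weights", shape=[4, 8])  # illustrative name and shape
    except ValueError as err:
        error_message = str(err)

print(error_message)
```

The same lookup happens inside `tf.nn.dynamic_rnn` when the cell builds its weights, which is why the failing name in the question starts with `LSTM/rnn/basic_lstm_cell/`.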

# build a LSTM network
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim, state_is_tuple=True, activation=tf.tanh)

with tf.variable_scope("LSTM") as vs:
    outputs, _states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
Y_pred = tf.contrib.layers.fully_connected(outputs[:, -1], output_dim, activation_fn=None)


#To reuse Variables
lstm_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,   scope=vs.name)

#Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

with tf.Session() as sess:
    init = tf.global_variables_initializer()

##### Add ops to save and restore all the variables.
saver = tf.train.Saver()

In another file, I want to reuse the trained model:

saver = tf.train.import_meta_graph("model.ckpt.meta")
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim, state_is_tuple=True, activation=tf.tanh, reuse=True)

with tf.variable_scope("LSTM", reuse=True) as vs:
    outputs, _states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
Y_pred = tf.contrib.layers.fully_connected(outputs[:, -1], output_dim, activation_fn=None)
lstm_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
tf.initialize_variables(lstm_variables)
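One commonly used way to run a saved graph on new data is to skip rebuilding the layers. `tf.train.import_meta_graph` already recreates the saved ops, so the input and output tensors can be fetched by name and the weights loaded with `saver.restore`. The sketch below is not the asker's model: a single weight matrix stands in for the LSTM, the tensor names `X` and `Y_pred` and the temporary checkpoint path are made up for illustration, and the code assumes a TensorFlow installation that provides `tensorflow.compat.v1` (the `tf.contrib` calls in the question exist only in TensorFlow 1.x).

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TF1-style graph mode, as in the question

# Hypothetical checkpoint location inside a fresh temporary directory.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# "Main file": build a tiny stand-in model, initialize it and save it.
with tf.Graph().as_default():
    X = tf.placeholder(tf.float32, [None, 3], name="X")
    with tf.variable_scope("LSTM"):
        W = tf.get_variable("weights", shape=[3, 1],
                            initializer=tf.ones_initializer())
    # Give the output an explicit name so it can be looked up later.
    Y_pred = tf.identity(tf.matmul(X, W), name="Y_pred")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)  # also writes ckpt_path + ".meta"

# "Other file": import the saved graph instead of rebuilding the layers.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt_path + ".meta")
    X = graph.get_tensor_by_name("X:0")
    Y_pred = graph.get_tensor_by_name("Y_pred:0")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)  # loads the saved weights
        new_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
        result = sess.run(Y_pred, feed_dict={X: new_data})

print(result)
```

Applied to the question's setup, this would likely mean naming the input placeholder and the prediction op when the model is first built, then looking them up after the import rather than calling `tf.nn.dynamic_rnn` a second time.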

0 answers:

No answers yet