However, I am trying to reuse a trained model (an RNN with LSTM cells) on a different dataset. When I use the code below, I get a ValueError:
ValueError: Variable LSTM/rnn/basic_lstm_cell/weights does not exist, or was not created with tf.get_variable()
Can anyone tell me what causes this? Thanks!
In the main file, the code is as follows:
# Build an LSTM network
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim, state_is_tuple=True, activation=tf.tanh)
with tf.variable_scope("LSTM") as vs:
    outputs, _states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
Y_pred = tf.contrib.layers.fully_connected(outputs[:, -1], output_dim, activation_fn=None)
# To reuse variables
lstm_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
In another file, I want to reuse the trained model:
saver = tf.train.import_meta_graph("model.ckpt.meta")
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim, state_is_tuple=True, activation=tf.tanh, reuse=True)
with tf.variable_scope("LSTM", reuse=True) as vs:
    outputs, _states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
Y_pred = tf.contrib.layers.fully_connected(outputs[:, -1], output_dim, activation_fn=None)
lstm_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
tf.initialize_variables(lstm_variables)
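
To make the error message concrete, here is a toy sketch in plain Python (not TensorFlow code) of the lookup that reuse=True appears to perform. The VariableStore class and its get_variable method are hypothetical stand-ins invented for illustration. The assumption, which I have not confirmed, is that a scope opened with reuse=True only finds names that were registered through tf.get_variable in the current graph, and that a newly started script begins with an empty registry even after tf.train.import_meta_graph has loaded the saved graph.

```python
class VariableStore:
    """Toy stand-in for a per-graph variable registry (hypothetical, not TensorFlow)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, reuse=False):
        if reuse:
            # With reuse=True nothing is created: the name must already be registered.
            if name not in self._vars:
                raise ValueError(
                    "Variable %s does not exist, or was not created with "
                    "get_variable()" % name
                )
            return self._vars[name]
        # Without reuse, the variable is created and registered under its name.
        if name in self._vars:
            raise ValueError("Variable %s already exists" % name)
        self._vars[name] = {"name": name}
        return self._vars[name]


NAME = "LSTM/rnn/basic_lstm_cell/weights"

# Main file: the variable is created first, so a later reuse lookup succeeds.
train_store = VariableStore()
created = train_store.get_variable(NAME)
shared = train_store.get_variable(NAME, reuse=True)
same_object = created is shared

# Second file: a new process starts with an empty registry, so reuse=True fails.
fresh_store = VariableStore()
try:
    fresh_store.get_variable(NAME, reuse=True)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(same_object)
print(error_message)
```

If that assumption holds, the second file would probably need one of two changes: build the network without import_meta_graph and without reuse=True and then load the weights with saver.restore, or keep import_meta_graph and fetch the existing tensors by name with get_tensor_by_name instead of rebuilding the network.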