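If the model is a CNN, I can save and restore it, but I cannot restore an RNN.
I built the RNN network as shown here. I want to save the trained weights and biases, or the whole model, and then run predictions without training again. Below is main.py: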
#main.py
import tensorflow as tf
tf_x = tf.placeholder(tf.float32, [None, seq_length, data_dim], name='tf_x')
tf_y = tf.placeholder(tf.int32, [None, output_dim], name='tf_y')
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(rnn_cell,
                                        tf_x,
                                        initial_state=None,
                                        dtype=tf.float32,
                                        time_major=False)
output = tf.layers.dense(outputs[:, -1, :], output_dim, name='dense_output')
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1))[1]
with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())  # the local variables are for the accuracy op
    sess.run(init_op)  # initialize the variables in the graph
    ...(training)
    saver = tf.train.Saver()
    save_path = saver.save(sess, "Save data/RNN-model")
    saver.export_meta_graph(filename="Save Data/RNN-model.meta", as_text=True)
Then, in run.py, I try to load that data:
#run.py
...(same as main.py)
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('Save data/')
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver = tf.train.import_meta_graph("Save data/RNN-model.meta")
    ... (prediction)
The result is:
tensorflow.python.framework.errors_impl.NotFoundError: Key dense/bias not found in checkpoint
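A "Key ... not found in checkpoint" error generally means the graph built at restore time asks for a variable name that the checkpoint file does not contain. One way to narrow this down is to compare the two sets of names. The sketch below is only an illustration, not a confirmed diagnosis: the helper `missing_keys` and both example name lists are hypothetical, and the commented-out TensorFlow 1.x calls show roughly how the real lists might be obtained.

```python
# Sketch: find variable names the graph expects but the checkpoint lacks.
# `missing_keys` and the example lists are hypothetical, for illustration only.

def missing_keys(graph_var_names, checkpoint_var_names):
    """Return, sorted, the graph variable names absent from the checkpoint."""
    return sorted(set(graph_var_names) - set(checkpoint_var_names))

# With TensorFlow 1.x the real lists could be gathered roughly like this
# (assumption: not executed here, shown only as a pointer):
#   graph_var_names = [v.op.name for v in tf.global_variables()]
#   checkpoint_var_names = [name for name, shape in tf.train.list_variables('Save data/')]

# Hypothetical names: a graph whose dense layer was created without a name,
# and a checkpoint whose dense layer was saved as 'dense_output'.
graph_var_names = [
    "rnn/basic_lstm_cell/kernel",
    "rnn/basic_lstm_cell/bias",
    "dense/kernel",
    "dense/bias",
]
checkpoint_var_names = [
    "rnn/basic_lstm_cell/kernel",
    "rnn/basic_lstm_cell/bias",
    "dense_output/kernel",
    "dense_output/bias",
]

print(missing_keys(graph_var_names, checkpoint_var_names))
```

If the real lists differ in this way, the layer names used when building the graph in run.py would need to match those used in main.py (which passes `name='dense_output'` to `tf.layers.dense`).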
What do you think the problem is?