TensorFlow error when saving/restoring a dynamic_rnn model

Time: 2017-09-05 12:00:37

Tags: python tensorflow

When the model is a CNN, I can save and restore it, but I cannot restore an RNN.

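I built the RNN as shown below. I want to save the trained weights and biases (or the whole model) and then run predictions without retraining. Here is main.py: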

# main.py
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)
tf_x = tf.placeholder(tf.float32, [None, seq_length, data_dim], name='tf_x')
tf_y = tf.placeholder(tf.int32, [None, output_dim], name='tf_y')

rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn( rnn_cell,
                                         tf_x,
                                         initial_state=None,
                                         dtype=tf.float32,
                                         time_major=False )

output = tf.layers.dense(outputs[:, -1, :], output_dim, name='dense_output')
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1),
                               predictions=tf.argmax(output, axis=1))[1]

with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # local variables are needed by tf.metrics.accuracy
    sess.run(init_op)     # initialize the variables in the graph

    ...(training)
    saver = tf.train.Saver()
    save_path = saver.save(sess, "Save data/RNN-model")
    saver.export_meta_graph(filename="Save Data/RNN-model.meta", as_text=True)

Then, in run.py, I try to load the saved data.

# run.py
...(same as main.py)
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('Save data/')
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver = tf.train.import_meta_graph("Save data/RNN-model.meta")
    ... (prediction)

The result is:

tensorflow.python.framework.errors_impl.NotFoundError: Key dense/bias not found in checkpoint
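A NotFoundError of this kind means that the graph built in run.py contains a variable named dense/bias, while the checkpoint file has no entry under that name. Since main.py names its dense layer 'dense_output', one possible (unconfirmed) explanation is that the two scripts do not build identically named layers. A minimal sketch of how the two sets of names could be compared is below. The name lists are hypothetical, made up for illustration; in TensorFlow 1.x they could presumably be obtained with tf.train.list_variables(checkpoint_path) for the checkpoint and [v.name for v in tf.global_variables()] for the graph.

```python
def diff_variable_names(graph_names, checkpoint_names):
    """Compare variable names in a graph against those stored in a checkpoint.

    Returns (missing, unused):
      missing - names the graph expects but the checkpoint does not contain
      unused  - names stored in the checkpoint that the graph never asks for
    """
    # Graph variable names carry an output suffix such as ":0"; strip it.
    graph = {name.split(":")[0] for name in graph_names}
    ckpt = set(checkpoint_names)
    return sorted(graph - ckpt), sorted(ckpt - graph)


# Hypothetical names, for illustration only (not taken from the real checkpoint).
graph_names = [
    "rnn/basic_lstm_cell/kernel:0",
    "rnn/basic_lstm_cell/bias:0",
    "dense/kernel:0",
    "dense/bias:0",
]
checkpoint_names = [
    "rnn/basic_lstm_cell/kernel",
    "rnn/basic_lstm_cell/bias",
    "dense_output/kernel",
    "dense_output/bias",
]

missing, unused = diff_variable_names(graph_names, checkpoint_names)
print("missing from checkpoint:", missing)
print("unused in checkpoint:  ", unused)
```

If the "missing" list is non-empty, the Saver in run.py is asking for names that were never saved, which is the situation the error message describes.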

What do you think the problem is?

0 answers:

No answers