我已经使用Seq2Seq模型构建了一个基本的聊天机器人。当我在笔记本中按顺序运行代码时,该bot非常有效-即构建模型->训练模型->测试模型。
我现在想在训练后保存模型,加载模型,然后测试模型。
但是,我遇到了问题/正在努力继续前进。
这是我到目前为止所得到的:
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'model_final.ckpt')
这似乎工作正常
saver = tf.train.import_meta_graph("model_final.ckpt.meta")
graph = tf.get_default_graph()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver.restore(sess, "model_final.ckpt")
这似乎工作正常
当我按顺序运行时,下面的代码完成提取输入问题,将其标记化并响应该问题的工作。
prediction_c = tf.argmax(model_c, 2)
result_c = sess_c.run(prediction_c,
feed_dict={enc_input_c: input_batch_c,
dec_input_c: output_batch_c,
targets_c: target_batch_c})
一旦我加载了Seq2Seq模型,就不确定诸如model_c,input_c之类的变量如何获取值/进行初始化。
对于这个问题的基本本质,或者我试图实现的目标没有道理,我深表歉意。我刚刚开始使用张量。
答案 0 :(得分:0)
您调查过吗?
检查第76-95行以获取恢复代码:https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq_restore.py
代码使用model.save和model.load分别保存和加载模型
要还原的模型是:https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py