如何加载和使用Seq2Seq模型?

时间:2019-04-27 12:59:18

标签: python tensorflow nltk seq2seq

我已经使用Seq2Seq模型构建了一个基本的聊天机器人。当我在笔记本中按顺序运行代码时,该bot非常有效-即构建模型->训练模型->测试模型。

我现在想在训练后保存模型,加载模型,然后测试模型。

但是,我遇到了问题/正在努力继续前进。

这是我到目前为止所得到的:

保存模型

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'model_final.ckpt')
这似乎工作正常

加载模型

saver = tf.train.import_meta_graph("model_final.ckpt.meta")
graph = tf.get_default_graph()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver.restore(sess, "model_final.ckpt")
这似乎工作正常

当我按顺序运行时,下面的代码完成提取输入问题,将其标记化并响应该问题的工作。

prediction_c  = tf.argmax(model_c, 2)
result_c = sess_c.run(prediction_c,
                  feed_dict={enc_input_c: input_batch_c,
                             dec_input_c: output_batch_c,
                             targets_c: target_batch_c})

一旦我加载了Seq2Seq模型,就不确定诸如model_c,input_c之类的变量如何获取值/进行初始化。

对于这个问题的基本本质,或者我试图实现的目标没有道理,我深表歉意。我刚刚开始使用张量。

1 个答案:

答案 0 :(得分:0)

您调查过吗?

检查第76-95行以获取恢复代码:https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq_restore.py

代码使用model.save和model.load分别保存和加载模型

要还原的模型是:https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py