如何使用训练和存储的张量流模型进行预测

时间:2019-01-30 23:25:33

标签: python tensorflow word2vec

我有一个训练有素的模型(特别是tensorflow word2vec https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb)。我可以很好地还原现有模型:

model1 = tf.train.import_meta_graph("models/model.meta")
model1.restore(sess, tf.train.latest_checkpoint("model/"))

但是我不知道如何使用新加载(训练有素)的模型进行预测。如何使用还原的模型进行预测?

编辑:

官方tensorflow回购https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py中的

模型代码

1 个答案:

答案 0 :(得分:1)

根据您如何加载检查点,我认为这应该是将其用于推理的最佳方法。

加载占位符:

input = tf.get_default_graph().get_tensor_by_name("Placeholders/placeholder_name:0")
....

加载用于执行预测的操作:

prediction = tf.get_default_graph().get_tensor_by_name("SomewhereInsideGraph/prediction_op_name:0")

创建会话,执行预测操作,并将数据输入占位符。

sess = tf.Session()
sess.run(prediction, feed_dict={input:input_data})

另一方面,我更喜欢做的是始终在类的构造函数中创建整个模型。然后,我将执行以下操作:

tf.reset_default_graph()
model = ModelClass()
loader = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
loader.restore(sess, path_to_checkpoint_dir)

由于您希望从训练有素的word2vec模型中将嵌入内容加载到另一个模型中,因此您应该执行以下操作:

embeddings_new_model = tf.Variable(...,name="embeddings")
embedding_saver = tf.train.Saver({"embeddings_word2vec": embeddings_new_model})
with tf.Session() as sess:
    embedding_saver.restore(sess, "word2vec_model_path")

假设word2vec模型中的embeddings变量名为embeddings_word2vec