如何恢复张量流模型?

时间:2016-07-18 03:46:29

标签: python tensorflow word2vec

我正在尝试使用.ckpt文件恢复模型,我通过在word2vec_optimized.py中运行tensorflow/models/embedding来获取该文件。我不知道如何恢复变量以便我可以加载模型并使用它,因为所有的tf变量都被封装并在tensorflow/models/embedding/word2vec_optimized.py的类中初始化。任何帮助,将不胜感激。

此外,如果我“恢复”创建的.ckpt,我现在有一个Wor2Vec实例,或者当我使用.ckpt恢复模型时我实际获得了什么?

1 个答案:

答案 0 :(得分:1)

当您在保护程序上调用保存功能时,您将传递用于训练模型的tf.Session。这包含对包含所有变量的图的引用。不要将python变量与tensorflow变量混淆。即使你不再在python中有一个指向你创建的张量流变量的变量,它仍然存在,如果它是计算图的一部分。创建模型后,请尝试运行以下代码。

for v in tf.all_variables():
    print(v.name)

这将打印出您创建的每个变量的名称。默认情况下,保存程序会保存所有这些内容。只要变量在恢复时具有相同的名称,它们的创建位置无关紧要。只需确保在将所有变量添加到模型后执行还原。为变量提供初始化程序时,只有在调用sess.run(tf.initialize_all_variables())时才会运行初始化。如果只是恢复值,则无需调用此方法。我经常使用以下代码。

sess = tf.Session()
saver = tf.train.Saver()
if 'restore' in sys.argv:
    saver.restore(sess, '/media/chase/98d61322-9ea7-473e-b835-8739c77d1e1e/model.chk')
else:
    sess.run(tf.initialize_all_variables())

当我使用在其中创建变量的thensorflow RNN类时,此代码正常工作。