在张量流中恢复预训练模型的麻烦

时间:2016-08-26 06:18:08

标签: python tensorflow word2vec

我运行了包含在TensorFlow中的word2vec演示程序,现在尝试从文件中恢复预训练模型,但它不起作用。

我运行了这个脚本文件: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py

然后我尝试运行这个文件:

#!/usr/bin/env python

import tensorflow as tf

FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"


def main():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(FILENAME_META)
        saver.restore(sess, FILENAME_CHECKPOINT)


if __name__ == '__main__':
    main()

失败并显示以下错误消息

Traceback (most recent call last):
  File "word2vec_restore.py", line 16, in <module>
    main()
  File "word2vec_restore.py", line 11, in main
    saver = tf.train.import_meta_graph(FILENAME_META)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
    return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
    producer_op_list=producer_op_list)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
    op_def = op_dict[node.op]
KeyError: 'Skipgram'

我认为我已经理解了TensorFlow的API文档,并且我实现了上面编写的代码。我是否以错误的方式使用Saver对象?

2 个答案:

答案 0 :(得分:3)

我自己解决了这个问题。我想知道关键的&#39; Skipgram&#39;来源,并挖掘源代码。要解决此问题,只需在顶部添加以下内容:

from tensorflow.models.embedding import gen_word2vec

我仍然不明白我在做什么,但也许这是因为有必要加载用C ++编写的相关库。

感谢。

答案 1 :(得分:0)

尝试以下方法:

saver = tf.train.Saver()
with tf.Session() as sess:
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)

其中checkpoint_dir是包含检查点文件的文件夹的路径,而不是元或检查点文件的完整路径。 Tensorflow从指定的文件夹中选择最新的检查点。