我运行了包含在TensorFlow中的word2vec演示程序,现在尝试从文件中恢复预训练模型,但它不起作用。
我运行了这个脚本文件: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py
然后我尝试运行这个文件:
#!/usr/bin/env python
import tensorflow as tf
FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"
def main():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(FILENAME_META)
saver.restore(sess, FILENAME_CHECKPOINT)
if __name__ == '__main__':
main()
失败并显示以下错误消息
Traceback (most recent call last):
File "word2vec_restore.py", line 16, in <module>
main()
File "word2vec_restore.py", line 11, in main
saver = tf.train.import_meta_graph(FILENAME_META)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
producer_op_list=producer_op_list)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
op_def = op_dict[node.op]
KeyError: 'Skipgram'
我认为我已经理解了TensorFlow的API文档,并且我实现了上面编写的代码。我是否以错误的方式使用Saver对象?
答案 0 :(得分:3)
我自己解决了这个问题。我想知道关键的&#39; Skipgram&#39;来源,并挖掘源代码。要解决此问题,只需在顶部添加以下内容:
from tensorflow.models.embedding import gen_word2vec
我仍然不明白我在做什么,但也许这是因为有必要加载用C ++编写的相关库。
感谢。
答案 1 :(得分:0)
尝试以下方法:
saver = tf.train.Saver()
with tf.Session() as sess:
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
其中checkpoint_dir
是包含检查点文件的文件夹的路径,而不是元或检查点文件的完整路径。 Tensorflow从指定的文件夹中选择最新的检查点。