Loading FaceNet's TensorFlow .pb file

Time: 2017-10-24 18:32:29

Tags: python machine-learning tensorflow

I downloaded a TensorFlow model from here. Now I am trying to load it in Python. According to {{3}}, a .pb file can be loaded like this:

import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile

fname = "20170512-110547.pb"

with tf.Session() as persisted_sess:
    print("load graph")
    with gfile.FastGFile(fname,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        persisted_sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    print("map variables")
    persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
    tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
    try:
        saver = tf.train.Saver(tf.all_variables()) # 'Saver' misnomer! Better: Persister!
    except:pass
    print("load data")
    saver.restore(persisted_sess, "checkpoint.data")  # now OK
    print(persisted_result.eval())
    print("DONE")
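A minimal sketch of loading the graph without the `Saver` step, assuming TensorFlow 1.x and assuming this .pb is a frozen graph (weights stored as constants, so there is nothing to restore from a checkpoint). The tensor names `input:0`, `embeddings:0` and `phase_train:0` are assumptions borrowed from FaceNet's own scripts, not from the post; `saved_result:0` in the snippet above comes from the copied example and is unlikely to exist in this model. Note also that `graph.as_default()` returns a context manager, so calling it without `with` does nothing by itself. The helper names `tensor_name`, `load_frozen_graph` and `compute_embeddings` are hypothetical.

```python
# Sketch: load a frozen GraphDef (.pb) and run the (assumed) FaceNet embedding tensor.
# Assumes TensorFlow 1.x. TensorFlow is imported inside the functions so the
# pure helper below can be used without TensorFlow installed.


def tensor_name(op_name, output_index=0):
    """Return the tensor name for output `output_index` of op `op_name`."""
    return "%s:%d" % (op_name, output_index)


def load_frozen_graph(pb_path):
    """Parse a frozen .pb file and import it into a fresh tf.Graph."""
    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():  # as_default() only takes effect inside `with`
        tf.import_graph_def(graph_def, name="")  # empty prefix keeps names like "input:0"
    return graph


def compute_embeddings(graph, images):
    """Feed a batch of images and return the embedding tensor's value."""
    import tensorflow as tf

    # Hypothetical tensor names, assumed from FaceNet's scripts.
    images_pl = graph.get_tensor_by_name(tensor_name("input"))
    embeddings = graph.get_tensor_by_name(tensor_name("embeddings"))
    phase_train = graph.get_tensor_by_name(tensor_name("phase_train"))

    with tf.Session(graph=graph) as sess:
        # No Saver/restore: a frozen graph already carries its weights as constants.
        return sess.run(embeddings, feed_dict={images_pl: images, phase_train: False})
```

Usage would be `graph = load_frozen_graph("20170512-110547.pb")` followed by `compute_embeddings(graph, batch)`, where `batch` is an image array preprocessed the way the model expects. With the TensorFlow build from the traceback, `load_frozen_graph` would still raise the same `KeyError`, because `tf.import_graph_def` is the call that fails there.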

But I get the error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-2-9995bff0b3cd> in <module>()
     12         graph_def.ParseFromString(f.read())
     13         persisted_sess.graph.as_default()
---> 14         tf.import_graph_def(graph_def, name='')
     15     print("map variables")
     16     persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")

/net/hciserver03/storage/oblum/venvs/compvisgpu01/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    256     for node in graph_def.node:
    257       # Set any default attr values that aren't present.
--> 258       op_def = op_dict[node.op]
    259       for attr_def in op_def.attr:
    260         key = attr_def.name

KeyError: u'FIFOQueueV2'
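The traceback shows where this fails: `import_graph_def` looks up each node's op type (`node.op`) in `op_dict`, the table of ops known to the installed TensorFlow, and `'FIFOQueueV2'` is not in it. That most likely means the .pb was exported with a newer TensorFlow than the one in this environment. A small diagnostic sketch, assuming TensorFlow 1.x, where the internal module `tensorflow.python.framework.op_def_registry` exposes `get_registered_ops()`; `missing_ops` and `find_unregistered_ops` are hypothetical helper names:

```python
# Sketch: list op types in a GraphDef that the installed TensorFlow does not know.
# Assumes TensorFlow 1.x; op_def_registry is an internal (non-public) module.


def missing_ops(graph_op_types, registered_op_types):
    """Return the sorted, de-duplicated op types absent from the registry."""
    return sorted(set(graph_op_types) - set(registered_op_types))


def find_unregistered_ops(pb_path):
    """Parse a .pb file and report op types unknown to this TensorFlow build."""
    import tensorflow as tf
    from tensorflow.python.framework import op_def_registry

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    registered = op_def_registry.get_registered_ops()  # dict: op name -> OpDef
    print("TensorFlow version: %s" % tf.__version__)
    return missing_ops((node.op for node in graph_def.node), registered)
```

In the environment from the traceback, `find_unregistered_ops("20170512-110547.pb")` would be expected to list at least `FIFOQueueV2`. If the list is non-empty, upgrading TensorFlow to (at least) the version that exported the model is the likelier fix than changing the loading code.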

0 answers:

No answers