在tensorflow .ckpt文件中使用pretraing模型

时间:2017-03-20 00:29:36

标签: python-3.x tensorflow deep-learning

我有一个ckpt文件。我只是想得到cnn的权重 我是从ckpt检查点文件中训练过的。 inception_resnet_v2_2016_08_30

import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "inception_resnet_v2_2016_08_30.ckpt")
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
with session.Session() as sess:
var_list = {}
reader =pywrap_tensorflow.NewCheckpointReader("./inception_resnet_v2_2016_08_30.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    try:
       tensor = sess.graph.get_tensor_by_name(key + ":0")
    except KeyError:
            continue
    var_list[key] = tensor
    saver = saver_lib.Saver(var_list=var_list)
    saver.restore(sess, input_checkpoint)
    if initializer_nodes:
       sess.run(initializer_nodes)

1 个答案:

答案 0 :(得分:1)

只有在您已经构建了要恢复检查点的图形结构(包括一组tf.train.Saver.restore()个对象)之后,tf.Variable方法才有效。您有(至少)两个选项来解决此问题:

  1. 使用tf.train.NewCheckpointReader("inception_resnet_v2_2016_08_30.ckpt")打开检查点文件。您可以在返回的对象上调用get_tensor()方法以按名称查找已保存的变量,或使用get_variable_to_shape_map()方法获取可用变量的列表。

  2. 如果您有一个load a MetaGraph用于检查点模式,其中包括图表结构以及从该图表结构到检查点中的变量的映射。