仅在张量流中加载模型一次

时间:2018-02-02 09:04:02

标签: python tensorflow image-recognition

我正在加载模型

def _load_model(model_filepath):
   model_exp = os.path.expanduser(model_filepath)
   if os.path.isfile(model_exp):
   print("loading model to graph")
   with gfile.FastGFile(model_exp, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')

并在以下代码中使用此功能

tf.reset_default_graph()
with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
    _load_model(model_filepath=model_path)
    test_set = _get_test_data(input_directory)
    images, labels = _load_images_and_labels(test_set, image_size=160, batch_size=batch_size,                                                               
    num_threads=num_threads, num_epochs=1)
    init_op = tf.group(tf.global_variables_initializer(), 
    tf.local_variables_initializer())   
    sess.run(init_op)
    images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
    embedding_layer = tf.get_default_graph().get_tensor_by_name("embeddings:0")
    phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")

在每次api调用时,我正在重置默认图表并加载需要很长时间的模型。 我只想加载我的模型一次,并在与新图形的会话中使用它 我怎么能做到这一点?

1 个答案:

答案 0 :(得分:0)

通常使用tf.train.Saver()保存并加载模型,请参阅docs

因此,在训练模型后,您可以执行以下操作:

saver.save(sess_name, "/path/model.ckpt")

当你想加载(“恢复”)时,你会做这样的事情:

saver = tf.train.Saver()
saver.restore(sess_name, "/path/model.ckpt")

正如Jonathan DEKHTIAR已经提到的那样,在提问之前使用搜索是有意义的: Tensorflow: how to save/restore a model?