我正在加载模型
def _load_model(model_filepath):
model_exp = os.path.expanduser(model_filepath)
if os.path.isfile(model_exp):
print("loading model to graph")
with gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
并在以下代码中使用此功能
tf.reset_default_graph()
with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
_load_model(model_filepath=model_path)
test_set = _get_test_data(input_directory)
images, labels = _load_images_and_labels(test_set, image_size=160, batch_size=batch_size,
num_threads=num_threads, num_epochs=1)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
embedding_layer = tf.get_default_graph().get_tensor_by_name("embeddings:0")
phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
在每次api调用时,我正在重置默认图表并加载需要很长时间的模型。 我只想加载我的模型一次,并在与新图形的会话中使用它 我怎么能做到这一点?
答案 0 :(得分:0)
通常使用tf.train.Saver()
保存并加载模型,请参阅docs。
因此,在训练模型后,您可以执行以下操作:
saver.save(sess_name, "/path/model.ckpt")
当你想加载(“恢复”)时,你会做这样的事情:
saver = tf.train.Saver()
saver.restore(sess_name, "/path/model.ckpt")
正如Jonathan DEKHTIAR已经提到的那样,在提问之前使用搜索是有意义的: Tensorflow: how to save/restore a model?