Tensorflow:如何加载SavedModel对象并继续训练?

时间:2018-05-22 01:21:38

标签: tensorflow

我使用以下代码保存了模型:

builder = tf.saved_model.builder.SavedModelBuilder(save_dir)

tensor_info_input_image = tf.saved_model.utils.build_tensor_info(input_image)
tensor_info_logits = tf.saved_model.utils.build_tensor_info(logits)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'net_input': tensor_info_input_image},
      outputs={'logits': tensor_info_logits},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

我以为我可以通过运行以下代码来加载它:

meta_graph_def = tf.saved_model.loader.load(sess,
                                [tf.saved_model.tag_constants.SERVING],
                                model_path)
graph = tf.get_default_graph()
net_input = graph.get_tensor_by_name(args.load_net_input_name)
net_output = graph.get_tensor_by_name(args.load_net_output_name)
network = graph.get_tensor_by_name(args.load_logits_name)
loss = graph.get_tensor_by_name(args.load_loss_name)
opt = graph.get_operation_by_name(args.load_optimizer_name)

但是,当我运行它时,loss的值似乎恢复到未经训练的模型。我的期望是能够在我离开的地方继续。此外,是否可以将培训分为两个阶段?

  • 第一次会议:使用原始数据进行培训。保存模型。
  • 第二个会话:加载模型。用增强数据训练。

谢谢。

0 个答案:

没有答案
相关问题