如何加载由export_inference_graph.py保存的训练模型?

时间:2020-04-22 21:29:47

标签: tensorflow conv-neural-network object-detection-api transfer-learning tensorflow-model-garden

我正在举一个使用tensorflow的1.15.0对象检测API的示例。 该教程在以下几个方面进行了明确说明:

  • 如何下载模型
  • 如何使用.xml文件加载自定义数据库,如何使用它们创建.cvs文件,然后.record文件
  • 如何配置培训管道
  • 如何获取张量板图
  • 如何训练净节省检查点(使用model_main.py)
  • 如何导出(保存)模型(使用export_inference_graph.py)

但是,我无法完成的工作是加载保存的模型以使用它。 我尝试使用tf.saved_model.loader.load(sess, flags, export_dir,但得到

INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.

export_dir中给出的文件夹具有以下结构:

+dir
   +saved_model
      -saved_model.pb
   -model.ckpt.data-00000-of-00001
   -model.ckpt.index
   -checkpoint
   -frozen_inference_graph.pb
   -model.ckpt.meta
   -pipeline.config

我在这里的最终目标是使用相机捕获图像,并将其馈送到网络以进行实时物体检测。\ 作为一个介于两者之间的步骤,现在我只希望能够提供一张图片并获得输出。我可以训练网络,但是现在我不能使用它。

谢谢。

1 个答案:

答案 0 :(得分:2)

我发现an example on how to download a model让我经历了它。\ 由于示例中下载的文件的文件夹格式与我在代码中得到的格式相同,因此只需要对其进行调整即可。

下载模型的原始功能是

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

然后我使用该函数来创建这个新的

def load_local_model(model_path):
  model_dir = pathlib.Path(model_path)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

起初这没有用,因为tf.saved_model.load期望有3个参数,但这是通过在同一示例中导入两个 import 块解决的,我不知道该怎么做诀窍和原因(我将在得到答案时对其进行编辑),但目前此代码有效,示例使我们可以做更多的事情。

导入块如下

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

编辑 要使其正常工作,真正需要的是以下方框。

import os
import pathlib


if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.

%%bash 
cd models/research
pip install .

否则此导入块将不起作用

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util