在Tensorflow中调试自定义模型保存

时间:2020-06-12 11:03:05

标签: python tensorflow

我正在使用this回购为我正在做的项目训练Tiny YOLOv3模型。我的问题是,当我使用tf.saved_model.save(model,output)保存模型,然后用tf.saved_model.load(output)重新加载模型时,产生的检测结果与我刚刚用model.load_weights(weights)加载权重的情况不同。我尝试在Tensorboard中可视化已保存的模型图,但可视化的模型看起来不像该模型。我已经在仓库中为同一问题打开了一个问题。但是,我想知道,我该如何进行调试并查找问题?关于如何检查和修复模型保存,我还没有在线找到任何资源。感谢您的任何帮助。

用于保存模型并显示检测结果的代码如下,取自仓库:

import time
from absl import app, flags, logging
from absl.flags import FLAGS
import cv2
import numpy as np
import tensorflow as tf
from yolov3_tf2.models import (
    YoloV3, YoloV3Tiny
)
from yolov3_tf2.dataset import transform_images

from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import nest
from yolov3_tf2.dataset import transform_images, load_tfrecord_dataset
from yolov3_tf2.utils import draw_outputs

flags.DEFINE_string('weights', './checkpoints/yolov3.tf',
                    'path to weights file')
flags.DEFINE_boolean('tiny', False, 'yolov3 or yolov3-tiny')
flags.DEFINE_string('output', './serving/yolov3/1', 'path to saved_model')
flags.DEFINE_string('classes', './data/coco.names', 'path to classes file')
flags.DEFINE_string('image', './data/girl.png', 'path to input image')
flags.DEFINE_integer('num_classes',1, 'number of classes in the model')

def main(_argv):
    if FLAGS.tiny:
        yolo = YoloV3Tiny(classes=FLAGS.num_classes)
    else:
        yolo = YoloV3(classes=FLAGS.num_classes)

    yolo.load_weights(FLAGS.weights).expect_partial()
    logging.info('weights loaded')

    # tf.saved_model.save(yolo, FLAGS.output)
    yolo.save(FLAGS.output)
    logging.info("model saved to: {}".format(FLAGS.output))

    model = tf.saved_model.load(FLAGS.output)
    infer = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    logging.info(infer.structured_outputs)

    class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
    logging.info('classes loaded')

    img_raw = tf.image.decode_image(open(FLAGS.image, 'rb').read(), channels=3)
    img = tf.expand_dims(img_raw, 0)
    img = transform_images(img, 416)

    t1 = time.time()
    outputs = infer(img)
    boxes, scores, classes, nums = outputs["yolo_nms"], outputs[
        "yolo_nms_1"], outputs["yolo_nms_2"], outputs["yolo_nms_3"]
    t2 = time.time()
    logging.info('time: {}'.format(t2 - t1))

    logging.info('detections:')
    for i in range(nums[0]):
        logging.info('\t{}, {}, {}'.format(class_names[int(classes[0][i])],
                                           scores[0][i].numpy(),
                                           boxes[0][i].numpy()))

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass

0 个答案:

没有答案