我正在使用this回购为我正在做的项目训练Tiny YOLOv3模型。我的问题是,当我使用tf.saved_model.save(model,output)
保存模型,然后用tf.saved_model.load(output)
重新加载模型时,产生的检测结果与我刚刚用model.load_weights(weights)
加载权重的情况不同。我尝试在Tensorboard中可视化已保存的模型图,但可视化的模型看起来不像该模型。我已经在仓库中为同一问题打开了一个问题。但是,我想知道,我该如何进行调试并查找问题?关于如何检查和修复模型保存,我还没有在线找到任何资源。感谢您的任何帮助。
用于保存模型并显示检测结果的代码如下,取自仓库:
import time
from absl import app, flags, logging
from absl.flags import FLAGS
import cv2
import numpy as np
import tensorflow as tf
from yolov3_tf2.models import (
YoloV3, YoloV3Tiny
)
from yolov3_tf2.dataset import transform_images
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import nest
from yolov3_tf2.dataset import transform_images, load_tfrecord_dataset
from yolov3_tf2.utils import draw_outputs
flags.DEFINE_string('weights', './checkpoints/yolov3.tf',
'path to weights file')
flags.DEFINE_boolean('tiny', False, 'yolov3 or yolov3-tiny')
flags.DEFINE_string('output', './serving/yolov3/1', 'path to saved_model')
flags.DEFINE_string('classes', './data/coco.names', 'path to classes file')
flags.DEFINE_string('image', './data/girl.png', 'path to input image')
flags.DEFINE_integer('num_classes',1, 'number of classes in the model')
def main(_argv):
if FLAGS.tiny:
yolo = YoloV3Tiny(classes=FLAGS.num_classes)
else:
yolo = YoloV3(classes=FLAGS.num_classes)
yolo.load_weights(FLAGS.weights).expect_partial()
logging.info('weights loaded')
# tf.saved_model.save(yolo, FLAGS.output)
yolo.save(FLAGS.output)
logging.info("model saved to: {}".format(FLAGS.output))
model = tf.saved_model.load(FLAGS.output)
infer = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
logging.info(infer.structured_outputs)
class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
logging.info('classes loaded')
img_raw = tf.image.decode_image(open(FLAGS.image, 'rb').read(), channels=3)
img = tf.expand_dims(img_raw, 0)
img = transform_images(img, 416)
t1 = time.time()
outputs = infer(img)
boxes, scores, classes, nums = outputs["yolo_nms"], outputs[
"yolo_nms_1"], outputs["yolo_nms_2"], outputs["yolo_nms_3"]
t2 = time.time()
logging.info('time: {}'.format(t2 - t1))
logging.info('detections:')
for i in range(nums[0]):
logging.info('\t{}, {}, {}'.format(class_names[int(classes[0][i])],
scores[0][i].numpy(),
boxes[0][i].numpy()))
if __name__ == '__main__':
try:
app.run(main)
except SystemExit:
pass