有关对象检测API检查点的查询

时间:2018-08-21 20:23:12

标签: python tensorflow image-processing object-detection object-detection-api

我有一些有关Tensorflow对象检测API的查询。

  1. 在训练时,仅存储前5个检查点。我想存储更多,比如说前面的10个检查要点。如何才能做到这一点? (我认为它应该是 object_detection / protos train.proto的参数之一。)

  2. 默认情况下,检查点每10分钟(600秒)存储一次。要更改此频率,我相信这是必须更改的两个参数之一,请确认它是哪个:

    来自learning.py /home/user/tensorflow-gpu/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim

    save_summaries_secs=600

    save_interval_secs=600

  3. 在训练模型(ssd_mobilenet_v2_coco_2018_03_29)的同时,我也同时运行评估。评估图中显示的最新检查点始终滞后于保存在 object_detection / training 文件夹中的最新检查点。例如,在以下情况下,图形上显示的最新检查点为29.437k,而模型已经过训练,直到检查点为32.891k(并保存在 training 文件夹中)。这种滞后(滞后20分钟)的原因是什么?为什么一步(十分钟)还不足以对经过训练的模型进行评估?

2 个答案:

答案 0 :(得分:1)

这适用于想要配置支持TensorFlow 2的更新的对象检测API的任何人

  1. 要保存前10个检查点,请打开model_lib.py并将关键字参数max_to_keep = 10传递给每个 tf.train.Saver 函数
  2. 要将频率从600秒更改为3600秒(1小时), 打开model_main.py并在主函数中找到包含 tf.estimator.RunConfig 的行。
    将关键字参数save_checkpoints_secs = 3600传递给 tf.estimator.RunConfig 类。

这是在model_main.py中配置检查点保存频率后的代码段:

def main(unused_argv):
      flags.mark_flag_as_required('model_dir')   
      flags.mark_flag_as_required('pipeline_config_path')   
      config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, save_checkpoints_secs=3600)

请注意,参数中有一个参数keep_checkpoint_max tf.estimator.RunConfig 类,但设置它不会影响我保存的检查点数量。

答案 1 :(得分:0)

我认为此帖子应该可以正常运行,我认为可以更改keep_checkpoint_every_n_hours max_to_keep

How to store best models checkpoints, not only newest 5, in Tensorflow Object Detection API?

您还可以参考官方文档 https://www.tensorflow.org/api_docs/python/tf/train/Saver