如何在Tensorflow Object Detection API中存储最佳模型检查点,而不仅仅是最新的5个检查点?

时间:2018-04-06 05:26:02

标签: tensorflow object-detection

我在WIDER FACE数据集上训练MobileNet,我遇到了无法解决的问题。 TF对象检测API仅存储train目录中的最后5个检查点,但我想要做的是保存相对于mAP度量的最佳模型(或者至少在train目录中保留更多模型删除)。 例如,今天我在第二天的训练后看了Tensorboard,我发现过夜模型已经过度拟合,我无法恢复最佳检查点,因为它已经被删除了。

编辑:我只使用Tensorflow Object Detection API,默认情况下会保存我指向的火车目录中的最后5个检查点。我寻找一些配置参数或任何会改变这种行为的东西。

有没有人在code / config param中设置/解决方法?看起来我错过了一些东西,显而易见的是,实际上重要的是最好的模型,而不是最新的模型(可以过度拟合)。

谢谢!

5 个答案:

答案 0 :(得分:2)

您可以修改(在您的分叉中硬编码或打开拉取请求并向protos添加选项)传递给tf.train.Saver的参数:

https://github.com/tensorflow/models/blob/master/research/object_detection/legacy/trainer.py#L376-L377

您可能想要设置:

  • max_to_keep:要保留的最近检查点的最大数量。默认为5。
  • keep_checkpoint_every_n_hours:保持检查点的频率。默认为10,000小时。

答案 1 :(得分:2)

您可以更改配置。

在run_config.py中

class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
           model_dir=None,
           tf_random_seed=None,
           save_summary_steps=100,
           save_checkpoints_steps=_USE_DEFAULT,
           save_checkpoints_secs=_USE_DEFAULT,
           session_config=None,
           keep_checkpoint_max=10,
           keep_checkpoint_every_n_hours=10000,
           log_step_count_steps=100,
           train_distribute=None,
           device_fn=None,
           protocol=None,
           eval_distribute=None,
           experimental_distribute=None):

答案 2 :(得分:1)

您可能对解决最新/最佳检查点问题的Tf github thread感兴趣。用户在tf.Saver附近开发了自己的包装器chekmate,以跟踪最佳检查点。

答案 3 :(得分:0)

您可以跟进this PR。此处,最佳检查点将保存在检查点目录中,该目录为名为“最佳”的子目录

您只需要在../ object_detection / eval_util.py 中集成 best_saver()和( _run_checkpoint_once()中的方法调用) >

此外,它还将为all_evaluation_metrices创建一个json。

答案 4 :(得分:0)

为了保存更多的检查点,您可以编写一个简单的python脚本,将检查点及时存储到特定的位置。

import os
import shutil
import time

while True:
    
    training_file = '/home/vignesh/training' # path of your train directory
    archive_file = 'home/vignesh/training/archive' #path of the directory where you want to save your checkpoints
    files_to_save = []

    for files in os.listdir(training_file):
        
        if files.rsplit('.')[0]=='model':
            
            files_to_save.append(files)

    for files in files_to_save:
        if files in os.listdir(archive_file):
            pass
        else:
            shutil.copy2(training_file+'/'+files,archive_file)
    time.sleep(600) # This will make the script run for every 600 seconds, modify it for your need