如何继续使用Tensorflow对象检测API训练对象检测模型?

时间:2018-11-01 15:25:11

标签: tensorflow machine-learning google-cloud-ml object-detection-api

我正在使用Tensorflow Object Detection API来通过转移学习训练对象检测模型。具体来说,我使用的是ssd_mobilenet_v1_fpn_coco from the model zoo,使用的是sample pipeline provided,当然是将占位符替换为指向我的培训,评估tfrecords和标签的实际链接。

我能够使用上述管道成功地在约5000张图像(以及相应的边界框)上训练模型(如果愿意的话,我主要是在TPU上使用Google的ML引擎)。

现在,我准备了大约2000张图像,并希望继续用这些新图像训练模型,而无需从头开始(训练初始模型花了大约6个小时的TPU时间)。我该怎么办?

2 个答案:

答案 0 :(得分:1)

您有两个选择,都需要更改新数据集的$corner-bevel: 20; .corner-bottom-left-bevel { width: 80px; height: 0; border-radius: 2px; border-style: solid; border-color: $green-color transparent transparent transparent; border-width: #{$corner-bevel}px 0 0 #{$corner-bevel}px; } 的{​​{1}}:

  1. 在训练配置中指定要微调的检查点时,请指定训练模型的检查点 input_path
  2. 只需继续使用与先前模型相同的配置train_input_reader即可(除了 train_config{ fine_tune_checkpoint: <path_to_your_checkpoint> fine_tune_checkpoint_type: true load_all_detection_checkpoint_vars: true } )。这样,API将创建一个图形,直到检查train_input_reader中是否已经存在检查点并适合该图形。如果是这样,它将恢复并继续训练。

答案 1 :(得分:0)

我还没有在新的数据集中重新训练对象检测模型,但是看起来 在配置文件中增加训练步骤train_config.num_steps的数量并在tfrecord文件中添加图像就足够了。