有什么方法只能使用tensorflow.estimator.train_and_evaluate()保存最佳模型吗?

时间:2019-03-05 14:46:03

标签: python tensorflow machine-learning computer-vision object-detection-api

我尝试使用tf.estimator.train_and_evaluate()方法从已经带有.config文件的检查点重新训练TF对象检测API模型,就像在models / research / object_detection / model_main.py中那样。它每N步或每N秒保存检查点。

但是我只想保存一种最好的模型,例如在Keras中。 有什么方法可以使用TF对象检测API模型吗?也许是tf.Estimator.train的某些选项/回调,还是将检测API与Keras结合使用的某种方式?

3 个答案:

答案 0 :(得分:3)

您可以尝试使用BestExporter。据我所知,这是您要做的唯一选择。

exporter = tf.estimator.BestExporter(
      compare_fn=_loss_smaller,
      exports_to_keep=5)

eval_spec = tf.estimator.EvalSpec(
    input_fn,
    steps,
    exporters)

https://www.tensorflow.org/api_docs/python/tf/estimator/BestExporter

答案 1 :(得分:2)

我一直在使用https://github.com/bluecamel/best_checkpoint_copier,对我来说效果很好。

示例:

best_copier = BestCheckpointCopier(
   name='best', # directory within model directory to copy checkpoints to
   checkpoints_to_keep=10, # number of checkpoints to keep
   score_metric='metrics/total_loss', # metric to use to determine "best"
   compare_fn=lambda x,y: x.score < y.score, # comparison function used to determine "best" checkpoint (x is the current checkpoint; y is the previously copied checkpoint with the highest/worst score)
   sort_key_fn=lambda x: x.score,
   sort_reverse=False) # sort order when discarding excess checkpoints

将其传递给您的eval_spec:

eval_spec = tf.estimator.EvalSpec(
   ...
   exporters=best_copier,
   ...)

答案 2 :(得分:1)

如果您正在使用张量流/模型的模型存储库进行训练。 可以修改models/research/object_detection/model_lib.py文件create_train_and_eval_specs的功能以包括最佳的导出程序:

final_exporter = tf.estimator.FinalExporter(
    name=final_exporter_name, serving_input_receiver_fn=predict_input_fn)

best_exporter = tf.estimator.BestExporter(
    name="best_exporter",
    serving_input_receiver_fn=predict_input_fn,
    event_file_pattern='eval_eval/*.tfevents.*',
    exports_to_keep=5)
exporters = [final_exporter, best_exporter]

train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn, max_steps=train_steps)

eval_specs = [
    tf.estimator.EvalSpec(
        name=eval_spec_name,
        input_fn=eval_input_fn,
        steps=eval_steps,
        exporters=exporters)
]