Tensorflow:仅在训练期间将错误最小化的情况下,如何才能保存检查点?

时间:2018-12-11 06:24:45

标签: tensorflow

我正在运行一个tensorflow程序,我想存储最佳模型供以后使用。我正在使用 estimator tf.contrib.tpu.TPUEstimator模块,该模块带有run_config参数,在其中我设置了save_checkpoints_secs=20*60)进行训练。

estimator.train以train_input_fn和num_train_steps作为参数。 例如:estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

我不想存储每隔n秒钟就保存检查点的方法,而是要存储训练时误差最小的最佳模型。

欢迎任何帮助。

1 个答案:

答案 0 :(得分:1)

tf.estimator.BestExporter似乎正是您要寻找的。根据{{​​3}},它指出:

  

每次创建新模型时,此类都会执行模型导出   比任何现有模型都要好。

  estimator = tf.estimator.DNNClassifier(
      config=tf.estimator.RunConfig(
          model_dir='/my_model', save_summary_steps=100),
      feature_columns=[categorial_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  serving_feature_spec = tf.feature_column.make_parse_example_spec(
      categorial_feature_a_emb)
  serving_input_receiver_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(
      serving_feature_spec))

  exporter = tf.estimator.BestExporter(
      name="best_exporter",
      serving_input_receiver_fn=serving_input_receiver_fn,
      exports_to_keep=5)

  train_spec = tf.estimator.TrainSpec(...)

  eval_spec = [tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=100,
    exporters=exporter,
    start_delay_secs=0,
    throttle_secs=5)]