使用tf.estimator.Estimator加载检查点和微调

时间:2017-09-26 10:27:42

标签: tensorflow

我们正在尝试将旧的培训代码转换为更符合tf.estimator.Estimator的代码。 在初始代码中,我们为目标数据集微调原始模型。在训练开始之前,只需使用 variables_to_restore init_fn MonitoredTrainingSession 的组合,从检查点加载一些图层。 如何用tf.estimator.Estimator方法实现这种重量加载?

2 个答案:

答案 0 :(得分:3)

你有两个选择,第一个更简单:

1-在var passcodeLockPresenter: PasscodeLockPresenter = { let configuration = PasscodeLockConfiguration() let presenter = PasscodeLockPresenter(mainWindow: self.view.window, configuration: configuration) return presenter }() passcodeLockPresenter.presentPasscodeLock()

中使用tf.train.init_from_checkpoint

2- model_fn会返回model_fn。您可以通过EstimatorSpec设置脚手架。

答案 1 :(得分:1)

import tensorflow as tf    

def model_fn():
  # your model defintion here
  # ...

# specify your saved checkpoint path
checkpoint_path = "model.ckpt"

ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_path)
est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)