我们正在尝试将旧的培训代码转换为更符合tf.estimator.Estimator的代码。 在初始代码中,我们为目标数据集微调原始模型。在训练开始之前,只需使用 variables_to_restore 和 init_fn 与 MonitoredTrainingSession 的组合,从检查点加载一些图层。 如何用tf.estimator.Estimator方法实现这种重量加载?
答案 0 :(得分:3)
你有两个选择,第一个更简单:
1-在var passcodeLockPresenter: PasscodeLockPresenter = {
let configuration = PasscodeLockConfiguration()
let presenter = PasscodeLockPresenter(mainWindow: self.view.window, configuration: configuration)
return presenter
}()
passcodeLockPresenter.presentPasscodeLock()
tf.train.init_from_checkpoint
2- model_fn
会返回model_fn
。您可以通过EstimatorSpec
设置脚手架。
答案 1 :(得分:1)
import tensorflow as tf
def model_fn():
# your model defintion here
# ...
# specify your saved checkpoint path
checkpoint_path = "model.ckpt"
ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_path)
est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)