tf.estimator.Estimator.train()是否维护input_fn状态

时间:2017-10-25 06:15:52

标签: tensorflow

我一直在使用自己的Estimator / Experiment代码超过一年,但我想最终跳上Dataset + Estimator的潮流。

我想做以下事情:

for _ in range(N):
  estimator.train(train_input_fn, steps=1000)
  estimator.evaluate(validation_input_fn)

train_input_fn创建tf.data.Dataset永久循环训练集,validation_input_fn创建tf.data.Dataset,执行验证集的一次传递。

train()跨呼叫保持train_input_fn的状态(即,如果引用匹配,则只调用一次)?这是人们如何使用Estimator进行训练循环吗?

2 个答案:

答案 0 :(得分:3)

正如我在上面的评论中提到的,看起来它不会在调用estimator.train()时保存状态。

我要使用的解决方案,可能是预期的方法,是将评估侦听器传递给estimator.train()。例如,

class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
  def __init__(self, estimator, input_fn):
    self.estimator = estimator
    self.input_fn = input_fn

  def after_save(self, session, global_step):
    self.estimator.evaluate(self.input_fn)

estimator.train(
  input_fn=lambda:_train_input_fn(...),
  max_steps=N,
  saving_listeners=[
    EvalCheckpointSaverListener(
      estimator,
      lambda:_eval_input_fn(...), 
    ),
  ],
)

答案 1 :(得分:2)

您现在还可以使用Estimator API中的train_and_evaluate方法。

这是它的工作原理:

estimator = tf.estimator.Estimator(
   model_fn=model_fn,
   model_dir=self.model_dir,
   params=params
)

train_spec = tf.estimator.TrainSpec(input_fn, max_steps=N)

eval_spec = tf.estimator.EvalSpec(
   validation_input_fn,
   steps=None,
   start_delay_secs=120, # start evaluating 120 seconds after beginning of training
   throttle_secs=600 # evaluate every 600 seconds
)


tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

请注意,评估之间的步数取决于计算时间,而不是global_step