我一直在使用自己的Estimator / Experiment代码超过一年,但我想最终跳上Dataset + Estimator的潮流。
我想做以下事情:
for _ in range(N):
estimator.train(train_input_fn, steps=1000)
estimator.evaluate(validation_input_fn)
train_input_fn
创建tf.data.Dataset
永久循环训练集,validation_input_fn
创建tf.data.Dataset
,执行验证集的一次传递。
train()
跨呼叫保持train_input_fn
的状态(即,如果引用匹配,则只调用一次)?这是人们如何使用Estimator进行训练循环吗?
答案 0 :(得分:3)
正如我在上面的评论中提到的,看起来它不会在调用estimator.train()
时保存状态。
我要使用的解决方案,可能是预期的方法,是将评估侦听器传递给estimator.train()
。例如,
class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
def __init__(self, estimator, input_fn):
self.estimator = estimator
self.input_fn = input_fn
def after_save(self, session, global_step):
self.estimator.evaluate(self.input_fn)
estimator.train(
input_fn=lambda:_train_input_fn(...),
max_steps=N,
saving_listeners=[
EvalCheckpointSaverListener(
estimator,
lambda:_eval_input_fn(...),
),
],
)
答案 1 :(得分:2)
您现在还可以使用Estimator
API中的train_and_evaluate
方法。
这是它的工作原理:
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=self.model_dir,
params=params
)
train_spec = tf.estimator.TrainSpec(input_fn, max_steps=N)
eval_spec = tf.estimator.EvalSpec(
validation_input_fn,
steps=None,
start_delay_secs=120, # start evaluating 120 seconds after beginning of training
throttle_secs=600 # evaluate every 600 seconds
)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
请注意,评估之间的步数取决于计算时间,而不是global_step
。