我正在使用Tensorflow的Estimator
训练模型,执行评估后经过2600步后它突然停止训练。是不是应该继续训练直到最后一个纪元结束?
def train():
train_input_func = lambda: input_fn(mode='train')
eval_input_func = lambda: input_fn(mode='eval')
est_conf = tf.estimator.RunConfig(cfg.model_dir, save_checkpoints_secs=120)
estimator = tf.estimator.Estimator(model_fn, cfg.model_dir, est_conf)
Path(estimator.eval_dir()).mkdir(parents=True, exist_ok=True)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_func)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_func, throttle_secs=120)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
if __name__ == '__main__':
train()
这是input_fn
函数:
def input_fn(mode=None):
data_generator = lambda: data_loader.data_generator(mode=mode)
dataset = tf.data.Dataset.from_generator(data_generator,
output_types=(tf.int32, tf.int32),
output_shapes=([None], [None]))
if mode is 'train':
dataset.shuffle(cfg.shuffle_buffer).repeat(1000)
dataset = dataset.padded_batch(cfg.batch_size, padded_shapes=([None],[None])).prefetch(1)
return dataset
答案 0 :(得分:0)
第一,您需要在TrainSpec定义中指定max_stps,如下所示:
train_spec = tf.estimator.TrainSpec(input_fn=train_input_func, max_steps=num_steps_you_specify)
第二 当input_fn抛出“ OutOfRangeError”时,训练过程将停止,在这种情况下,max_step将无法按预期工作。因此,为了使训练贯穿整个时期,您需要像以下所示指定input_fn:
dataset = dataset.repeat()# don't specify any number in the repeat()
希望这会对您有所帮助。
答案 1 :(得分:0)
问题是我没有分配dataset.shuffle(cfg.shuffle_buffer).repeat(1000)
。这样可以解决问题:
dataset = dataset.shuffle(cfg.shuffle_buffer).repeat(1000)