用tf.estimator提早停止,怎么样?

时间:2017-11-06 12:28:11

标签: python tensorflow neural-network keras tensorflow-estimator

我在TensorFlow 1.4中使用tf.estimator并且tf.estimator.train_and_evaluate很棒,但我需要提前停止。添加它的首选方法是什么?

我假设某处有一些tf.train.SessionRunHook。我看到有一个带有ValidationMonitor的旧的contrib包似乎已经提前停止了,但它似乎不再是1.4了。或者将来的首选方式是依靠tf.keras(早期停止真的很容易)而不是tf.estimator/tf.layers/tf.data,或许?

4 个答案:

答案 0 :(得分:22)

好消息! tf.estimator现在已经在master上提供了早期停止支持,并且看起来将在1.10中。

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

答案 1 :(得分:2)

是的,有tf.train.StopAtStepHook

  

在执行了许多步骤或达到最后一步之后,此挂钩请求停止。只能指定两个选项中的一个。

您还可以根据步骤结果对其进行扩展并实施自己的停止策略。

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()

答案 2 :(得分:1)

另一个不使用钩子的选项是创建一个tf.contrib.learn.Experiment(即使在contrib中,它似乎也支持新的tf.estimator.Estimator)。

然后使用适当定制的continuous_train_and_eval通过(显然是实验性的)方法[{1}}进行训练。

根据tensorflow文档,continuous_eval_predicate_fn

  

判断每次迭代后是否继续eval的谓词函数。

并使用上次评估运行中的continuous_eval_predicate_fn进行调用。对于提前停止,使用自定义函数将状态保持为当前最佳结果和计数器,并在达到提前停止条件时返回eval_results

注意补充:这种方法将使用弃用的方法w / tensorflow 1.7(从该版本开始不推荐使用tf.contrib.learn:https://www.tensorflow.org/api_docs/python/tf/contrib/learn

答案 3 :(得分:1)

首先,您必须将损失命名为提前停止呼叫。如果您的损失变量被命名为"损失"在估算器中,行

copyloss = tf.identity(loss, name="loss")
它下面的

会起作用。

然后,使用此代码创建一个钩子。

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

这将指数平滑损失验证与其最低值进行比较,如果它的容差更高,则会停止训练。如果它太早停止,提高公差和平滑将使其稍后停止。保持平滑低于一,否则永远不会停止。

如果要根据不同的条件停止,可以使用其他内容替换after_run中的逻辑。

现在,将此挂钩添加到评估规范中。您的代码应如下所示:

eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

重要提示:函数run_context.request_stop()在train_and_evaluate调用中被破坏,并且不会停止训练。所以,我提出了一个值错误来停止训练。所以你必须将train_and_evaluate调用包装在try catch块中,如下所示:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

如果您不这样做,代码将在培训停止时因错误而崩溃。