使用Estimators时,使用tf.train.SessionRunHook替换验证监视器

时间:2017-06-28 04:33:11

标签: validation tensorflow

我正在运行DNNClassifier,我正在训练时监控其准确性。来自contrib / learn的monitors.ValidationMonitor一直很好用,在我的实现中我定义了它:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

然后使用来自:

的电话
clf.fit(input_fn=lambda: input_fn(A, Cl2),
            steps=1000, monitors=[validation_monitor])

其中:

clf = tensorflow.contrib.learn.DNNClassifier(...

这很好用。也就是说,验证监视器似乎已被弃用,并且类似的功能将被tf.train.SessionRunHook替换。

我是TensorFlow的新手,对我而言,这样的替换实现看起来似乎并不重要。任何建议都非常感谢。同样,我需要在特定步骤后验证培训。 非常感谢提前。

4 个答案:

答案 0 :(得分:16)

有一个名为monitors.replace_monitors_with_hooks()的未记录的实用程序,它将监视器转换为钩子。该方法接受(i)可能包含监视器和钩子的列表,以及(ii)将使用钩子的Estimator,然后通过在每个Monitor周围包装SessionRunHook来返回钩子列表。

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

clf = tf.estimator.Estimator(...)

list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)

对于完全替换ValidationMonitor的问题,这不是一个真正的 true 解决方案 - 我们只是用一个非弃用的函数来包装它。但是,我可以说这对我来说是有用的,因为它保留了我需要的所有功能来自ValidationMonitor(即评估每个 n 步骤,提前停止使用指标等)

还有一件事 - 使用这个钩子你需要从tf.contrib.learn.Estimator(只接受监视器)更新到更成熟的官方tf.estimator.Estimator(只接受挂钩)。因此,您应该将分类器实例化为tf.estimator.DNNClassifier,并使用其方法train()进行训练(这只是fit()的重新命名):

clf = tf.estimator.Estimator(...)

...

clf.train(
    input_fn=...
    ...
    hooks=hooks)

答案 1 :(得分:4)

我设法提出了一种按照建议扩展tf.train.SessionRunHook的方法。

import tensorflow as tf


class ValidationHook(tf.train.SessionRunHook):
    def __init__(self, model_fn, params, input_fn, checkpoint_dir,
                 every_n_secs=None, every_n_steps=None):
        self._iter_count = 0
        self._estimator = tf.estimator.Estimator(
            model_fn=model_fn,
            params=params,
            model_dir=checkpoint_dir
        )
        self._input_fn = input_fn
        self._timer = tf.train.SecondOrStepTimer(every_n_secs, every_n_steps)
        self._should_trigger = False

    def begin(self):
        self._timer.reset()
        self._iter_count = 0

    def before_run(self, run_context):
        self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)

    def after_run(self, run_context, run_values):
        if self._should_trigger:
            self._estimator.evaluate(
                self._input_fn
            )
            self._timer.update_last_triggered_step(self._iter_count)
        self._iter_count += 1

并将其用作training_hook中的Estimator.train

estimator.train(input_fn=_input_fn(...),
                steps=num_epochs * num_steps_per_epoch,
                hooks=[ValidationHook(...)])

它没有任何花哨的东西,ValidationMonitor有早期停止等等,但这应该是一个开始。

答案 2 :(得分:0)

由于您要在每n_step之后验证一次训练,因此tf将使用最新保存的检查点。使用CheckpointSaverListener保存检查点后,可以使用自定义CheckpointSaverHook类添加评估步骤。 将模型分类器对象和评估输入函数传递给类

引用https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener

class ExampleCheckpointSaverListener(CheckpointSaverListener):
  def __init(self):
    self.classifier = classifier
    self.eval_input_fn = eval_input_fn

  def begin(self):
    # You can add ops to the graph here.
    print('Starting the session.')
    self.your_tensor = ...

  def before_save(self, session, global_step_value):
    print('About to write a checkpoint')
    eval_op = self.classifier.evaluate(input_fn=self.eval_input_fn)
    print(eval_op)

  def after_save(self, session, global_step_value):
    print('Done writing checkpoint.')

  def end(self, session, global_step_value):
    print('Done with the session.')

...
listener = ExampleCheckpointSaverListener(Myclassifier, eval_input_fn )
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):

答案 3 :(得分:0)

https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener

我使用检查点保存侦听器来监视训练,您使用的估计量可能有一个名为saving_listeners的参数。每次创建检查点时都会调用它,这是您可以在估算器的config中设置的参数。因此,钩子是:

class ValidationListener(tf.train.CheckpointSaverListener):
    def __init__(self, estimator, input_fn):
        self._estimator = estimator
        self._input_fn = input_fn
        self._evaluation_loss = 9999

    def after_save(self, run_context, run_values):
        print("--- done writing checkpoint. ---")
        evaluation = self._estimator.evaluate(input_fn=self._input_fn)
        print(evaluation)
        if evaluation['loss'] < self._evaluation_loss:
            self._evaluation_loss = evaluation['loss']
        else:
            return True # Stop Training

训练时:

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    hard_estimator,
    metric_name='loss',
    max_steps_without_decrease=500,
    min_steps=100)

monitor_validation = ValidationListener(estimator=hard_estimator, input_fn=hard_validation_input_fn)

hard_estimator.train(
    input_fn = train_input_fn,
    hooks=[early_stopping],
    steps=1000,
    saving_listeners=[monitor_validation]
)

希望这会有所帮助。