我正在运行DNNClassifier,我正在训练时监控其准确性。来自contrib / learn的monitors.ValidationMonitor一直很好用,在我的实现中我定义了它:
validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)
然后使用来自:
的电话clf.fit(input_fn=lambda: input_fn(A, Cl2),
steps=1000, monitors=[validation_monitor])
其中:
clf = tensorflow.contrib.learn.DNNClassifier(...
这很好用。也就是说,验证监视器似乎已被弃用,并且类似的功能将被tf.train.SessionRunHook
替换。
我是TensorFlow的新手,对我而言,这样的替换实现看起来似乎并不重要。任何建议都非常感谢。同样,我需要在特定步骤后验证培训。 非常感谢提前。
答案 0 :(得分:16)
有一个名为monitors.replace_monitors_with_hooks()
的未记录的实用程序,它将监视器转换为钩子。该方法接受(i)可能包含监视器和钩子的列表,以及(ii)将使用钩子的Estimator,然后通过在每个Monitor周围包装SessionRunHook来返回钩子列表。
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
clf = tf.estimator.Estimator(...)
list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)
对于完全替换ValidationMonitor的问题,这不是一个真正的 true 解决方案 - 我们只是用一个非弃用的函数来包装它。但是,我可以说这对我来说是有用的,因为它保留了我需要的所有功能来自ValidationMonitor(即评估每个 n 步骤,提前停止使用指标等)
还有一件事 - 使用这个钩子你需要从tf.contrib.learn.Estimator
(只接受监视器)更新到更成熟的官方tf.estimator.Estimator
(只接受挂钩)。因此,您应该将分类器实例化为tf.estimator.DNNClassifier
,并使用其方法train()
进行训练(这只是fit()
的重新命名):
clf = tf.estimator.Estimator(...)
...
clf.train(
input_fn=...
...
hooks=hooks)
答案 1 :(得分:4)
我设法提出了一种按照建议扩展tf.train.SessionRunHook
的方法。
import tensorflow as tf
class ValidationHook(tf.train.SessionRunHook):
def __init__(self, model_fn, params, input_fn, checkpoint_dir,
every_n_secs=None, every_n_steps=None):
self._iter_count = 0
self._estimator = tf.estimator.Estimator(
model_fn=model_fn,
params=params,
model_dir=checkpoint_dir
)
self._input_fn = input_fn
self._timer = tf.train.SecondOrStepTimer(every_n_secs, every_n_steps)
self._should_trigger = False
def begin(self):
self._timer.reset()
self._iter_count = 0
def before_run(self, run_context):
self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
def after_run(self, run_context, run_values):
if self._should_trigger:
self._estimator.evaluate(
self._input_fn
)
self._timer.update_last_triggered_step(self._iter_count)
self._iter_count += 1
并将其用作training_hook
中的Estimator.train
:
estimator.train(input_fn=_input_fn(...),
steps=num_epochs * num_steps_per_epoch,
hooks=[ValidationHook(...)])
它没有任何花哨的东西,ValidationMonitor
有早期停止等等,但这应该是一个开始。
答案 2 :(得分:0)
由于您要在每n_step之后验证一次训练,因此tf将使用最新保存的检查点。使用CheckpointSaverListener
保存检查点后,可以使用自定义CheckpointSaverHook
类添加评估步骤。
将模型分类器对象和评估输入函数传递给类
引用https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener
class ExampleCheckpointSaverListener(CheckpointSaverListener):
def __init(self):
self.classifier = classifier
self.eval_input_fn = eval_input_fn
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
self.your_tensor = ...
def before_save(self, session, global_step_value):
print('About to write a checkpoint')
eval_op = self.classifier.evaluate(input_fn=self.eval_input_fn)
print(eval_op)
def after_save(self, session, global_step_value):
print('Done writing checkpoint.')
def end(self, session, global_step_value):
print('Done with the session.')
...
listener = ExampleCheckpointSaverListener(Myclassifier, eval_input_fn )
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
答案 3 :(得分:0)
https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train https://www.tensorflow.org/api_docs/python/tf/train/CheckpointSaverListener
我使用检查点保存侦听器来监视训练,您使用的估计量可能有一个名为saving_listeners
的参数。每次创建检查点时都会调用它,这是您可以在估算器的config
中设置的参数。因此,钩子是:
class ValidationListener(tf.train.CheckpointSaverListener):
def __init__(self, estimator, input_fn):
self._estimator = estimator
self._input_fn = input_fn
self._evaluation_loss = 9999
def after_save(self, run_context, run_values):
print("--- done writing checkpoint. ---")
evaluation = self._estimator.evaluate(input_fn=self._input_fn)
print(evaluation)
if evaluation['loss'] < self._evaluation_loss:
self._evaluation_loss = evaluation['loss']
else:
return True # Stop Training
训练时:
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
hard_estimator,
metric_name='loss',
max_steps_without_decrease=500,
min_steps=100)
monitor_validation = ValidationListener(estimator=hard_estimator, input_fn=hard_validation_input_fn)
hard_estimator.train(
input_fn = train_input_fn,
hooks=[early_stopping],
steps=1000,
saving_listeners=[monitor_validation]
)
希望这会有所帮助。