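I'm trying to use InMemoryEvaluatorHook with TPUEstimator to get validation statistics while training a model. Looping over estimator.train() and estimator.evaluate() is too expensive, because it rebuilds the graph on every epoch instead of reusing it (as described in this issue: https://github.com/tensorflow/tensorflow/issues/13895). Here is the basic code I'm using: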
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=self.batch_size,
    eval_batch_size=self.batch_size,
    predict_batch_size=self.batch_size,
    params={})
train_fn = lambda params: input_fn(
    'train', self.data_dir, batch_size=params['batch_size'], train=True)
val_fn = lambda params: input_fn(
    'validation',
    self.data_dir,
    batch_size=params['batch_size'],
    train=False)
train_hook = tf.contrib.estimator.InMemoryEvaluatorHook(
    estimator,
    val_fn,
    steps=self.steps_per_val_epoch,
    every_n_iter=self.steps_per_epoch)
estimator.train(
    input_fn=train_fn,
    steps=self.steps_per_epoch * self.max_num_training_epochs,
    hooks=[
        train_hook,
    ])
This results in the following error:
Traceback (most recent call last):
  File "dev/google_communicator/worker.py", line 160, in <module>
    main()
  File "dev/google_communicator/worker.py", line 133, in main
    results = evaluator.eval(inputs, outputs)
  File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 278, in eval
    train_hook,
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/estimator/python/estimator/hooks.py", line 135, in begin
    self._input_fn, self._hooks, checkpoint_path=None)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1484, in _evaluate_build_graph
    self._call_model_fn_eval(input_fn, self.config))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1520, in _call_model_fn_eval
    features, labels, model_fn_lib.ModeKeys.EVAL, config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2195, in _call_model_fn
    features, labels, mode, config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2631, in _model_fn
    rendezvous=self._rendezvous[mode]),
KeyError: 'eval'
Is there a better way to get validation statistics when training on TPUs? If not, how should validation be done?
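My reading of the last frames of the traceback is that the lookup self._rendezvous[mode] fails because nothing has stored an entry under 'eval' by the time the hook builds the evaluation graph. Below is a toy sketch of that failure pattern only. ToyEstimator and its methods are hypothetical names I made up, and the idea that the entry is registered by the public train()/evaluate() calls is my assumption, not something I have confirmed in the TPUEstimator source.

```python
class ToyEstimator(object):
    """Hypothetical stand-in that mimics a per-mode registry (not real TPUEstimator code)."""

    def __init__(self):
        self._rendezvous = {}  # mode name -> per-mode state

    def train(self):
        # Assumption: the public entry point registers state for its own mode.
        self._rendezvous['train'] = object()
        return self.build_graph('train')

    def evaluate(self):
        self._rendezvous['eval'] = object()
        return self.build_graph('eval')

    def build_graph(self, mode):
        # Raises KeyError when no entry was registered for `mode`.
        return self._rendezvous[mode]


# Case 1: only train() has run, then something builds the eval graph directly,
# which is roughly what a hook does if it bypasses evaluate().
fresh = ToyEstimator()
fresh.train()
try:
    fresh.build_graph('eval')
    hook_failed = False
    missing_key = None
except KeyError as err:
    hook_failed = True
    missing_key = err.args[0]

# Case 2: evaluate() has been called once, so the 'eval' entry exists.
warmed = ToyEstimator()
warmed.train()
warmed.evaluate()
warmed_ok = warmed.build_graph('eval') is not None
```

In this toy version the direct eval-graph build fails with a KeyError for 'eval' until evaluate() has been called at least once. Whether the real library behaves the same way is exactly what I am unsure about.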
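Edit: I seem to have gotten past this error by first running estimator.train() and estimator.evaluate() once without the hook, and then running the full training with the hook. Unfortunately, there is another error when training resumes after the first evaluation: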
Traceback (most recent call last):
  File "dev/google_communicator/worker.py", line 160, in <module>
    main()
  File "dev/google_communicator/worker.py", line 133, in main
    results = evaluator.eval(inputs, outputs)
  File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 329, in eval
    train_hook,
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1471, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 671, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1156, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1255, in run
    raise six.reraise(*original_exc_info)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1240, in run
    return self._sess.run(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1312, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
    return self._sess.run(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: The TPU system has not been initialized.
[[{{node TPUReplicate/_compile/_14248540389241865347/_28}} = TPUCompile[NumDynamicShapes=0, Tguaranteed_constants=[], function=cluster_18378946049549366873_f15n_0[], metadata="\n\006\010...6\323\352L", num_computations=1, _device="/job:worker/replica:0/task:0/device:CPU:0"](^cluster/control_before/_0)]]
[[{{node tpu_compile_succeeded_assert/_1897752282630996029/_29_G679}} = _Recv[client_terminated=false, recv_device="/job:worker/replica:0/task:0/device:TPU:2", send_device="/job:worker/replica:0/task:0/device:CPU:0", send_device_incarnation=2337451129362726278, tensor_name="edge_174_tpu_compile_succeeded_assert/_1897752282630996029/_29", tensor_type=DT_FLOAT, _device="/job:worker/replica:0/task:0/device:TPU:2"]()]]
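To clarify, this is what happens before the error is raised: the two initial calls on the estimator (train and evaluate), one epoch of training, and an evaluation on the validation set. The exception is raised when the estimator tries to resume training.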
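The order of calls described above can be sketched as follows. run_with_warmup and RecordingEstimator are hypothetical names introduced only for illustration, and warmup_steps=1 is an assumed value. The recording class just logs the calls so the sequence can be checked without a TPU; in the real run the estimator is the TPUEstimator and the hook is the InMemoryEvaluatorHook from the code at the top.

```python
def run_with_warmup(estimator, train_fn, val_fn, hook, total_steps, warmup_steps=1):
    """Call train() and evaluate() once without the hook, then train with it."""
    # Warm-up calls without the hook (warmup_steps is an assumed value).
    estimator.train(input_fn=train_fn, steps=warmup_steps)
    estimator.evaluate(input_fn=val_fn, steps=warmup_steps)
    # Full training run with the in-memory evaluation hook attached.
    estimator.train(input_fn=train_fn, steps=total_steps, hooks=[hook])


class RecordingEstimator(object):
    """Hypothetical fake that records (method, steps, has_hooks) for each call."""

    def __init__(self):
        self.calls = []

    def train(self, input_fn, steps, hooks=None):
        self.calls.append(('train', steps, bool(hooks)))

    def evaluate(self, input_fn, steps):
        self.calls.append(('evaluate', steps, False))


fake = RecordingEstimator()
run_with_warmup(fake, train_fn=None, val_fn=None, hook=object(), total_steps=100)
```

With the real estimator, the failure shows up inside the last train() call, after the hook has run its first evaluation.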
This unresolved issue may be related: https://github.com/tensorflow/tensor2tensor/issues/1202