Using InMemoryEvaluatorHook with TPU throws an exception

Time: 2019-03-08 21:56:15

Tags: tensorflow tensorflow-estimator tpu

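I'm trying to use InMemoryEvaluatorHook with TPUEstimator to get validation statistics while training a model. A loop of estimator.train() and estimator.evaluate() calls is too expensive, because it rebuilds the graph every epoch instead of reusing it (as described in this issue: https://github.com/tensorflow/tensorflow/issues/13895). Here is the basic code I'm using: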

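import tensorflow as tf  # TF 1.x: tf.contrib is not available in TF 2.x

# Excerpt from a method: model_fn, run_config and input_fn are defined
# elsewhere, and self.* holds the experiment settings.
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=self.batch_size,
    eval_batch_size=self.batch_size,
    predict_batch_size=self.batch_size,
    params={})

# TPUEstimator supplies the batch size to input_fn through params['batch_size'].
train_fn = lambda params: input_fn(
    'train', self.data_dir, batch_size=params['batch_size'], train=True)
val_fn = lambda params: input_fn(
    'validation',
    self.data_dir,
    batch_size=params['batch_size'],
    train=False)

# Evaluate on the validation set every steps_per_epoch training steps,
# instead of calling estimator.evaluate() in a separate loop.
train_hook = tf.contrib.estimator.InMemoryEvaluatorHook(
    estimator,
    val_fn,
    steps=self.steps_per_val_epoch,
    every_n_iter=self.steps_per_epoch)
estimator.train(
    input_fn=train_fn,
    steps=self.steps_per_epoch * self.max_num_training_epochs,
    hooks=[
        train_hook,
    ])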
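A toy cost model of the trade-off described in the first paragraph. It is plain Python, not TensorFlow, and it assumes (following the linked issue) that every Estimator.train() or evaluate() call builds a new graph, whereas the hook builds its evaluation graph once and reuses it. The function names are hypothetical, and the model counts graph builds only, ignoring session and TPU start-up costs.

```python
# Toy cost model, not TensorFlow code. Assumption taken from the linked issue:
# every Estimator.train()/evaluate() call builds a fresh graph, whereas an
# in-memory evaluator hook builds its eval graph once and reuses it.


def builds_alternating(num_epochs):
    """Graph builds for a loop with one train()/evaluate() pair per epoch."""
    builds = 0
    for _ in range(num_epochs):
        builds += 1  # estimator.train(input_fn=train_fn, steps=steps_per_epoch)
        builds += 1  # estimator.evaluate(input_fn=val_fn, steps=...)
    return builds


def builds_with_hook(num_epochs):
    """Graph builds for one train() call that carries the evaluator hook."""
    del num_epochs  # independent of the epoch count in this model
    train_graph = 1  # a single estimator.train() call covers every epoch
    eval_graph = 1   # built once in the hook's begin(), reused each epoch
    return train_graph + eval_graph


print(builds_alternating(10), builds_with_hook(10))  # 20 2
```

Under these assumptions the alternating loop grows linearly with the number of epochs, while the hook-based run stays constant.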

Calling estimator.train() with train_hook results in the following error:

Traceback (most recent call last):
  File "dev/google_communicator/worker.py", line 160, in <module>
    main()
  File "dev/google_communicator/worker.py", line 133, in main
    results = evaluator.eval(inputs, outputs)
  File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 278, in eval
    train_hook,
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/estimator/python/estimator/hooks.py", line 135, in begin
    self._input_fn, self._hooks, checkpoint_path=None)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1484, in _evaluate_build_graph
    self._call_model_fn_eval(input_fn, self.config))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1520, in _call_model_fn_eval
    features, labels, model_fn_lib.ModeKeys.EVAL, config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2195, in _call_model_fn
    features, labels, mode, config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2631, in _model_fn
    rendezvous=self._rendezvous[mode]),
KeyError: 'eval'
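The last frame points at a plausible cause: TPUEstimator._model_fn looks up self._rendezvous[mode], and the hook's begin() builds the eval graph while train() is running. The plain-Python toy model below (ToyTPUEstimator and toy_in_memory_eval_hook are hypothetical stand-ins, not TensorFlow API) reproduces that lookup, assuming that a mode's entry is only registered when the matching public method, such as evaluate(), is entered.

```python
# Toy model, not the real TPUEstimator. It mimics only what the last traceback
# frame shows: _model_fn looks up self._rendezvous[mode]. Assumption: a mode's
# key is registered only when the matching public method is entered.


class ToyTPUEstimator:
    def __init__(self):
        self._rendezvous = {}  # mode -> per-call error rendezvous (stand-in)

    def _model_fn(self, mode):
        # Mirrors "rendezvous=self._rendezvous[mode]" in the traceback.
        return self._rendezvous[mode]

    def evaluate(self):
        self._rendezvous["eval"] = object()  # key registered on entry
        return self._model_fn("eval")

    def train(self, hooks=()):
        self._rendezvous["train"] = object()  # only the "train" key is added
        for hook in hooks:
            hook(self)  # stands in for InMemoryEvaluatorHook.begin()
        return self._model_fn("train")


def toy_in_memory_eval_hook(estimator):
    # Hypothetical hook: builds the eval graph from inside train(), in eval mode.
    estimator._model_fn("eval")


# 1) Hook on a fresh estimator: the "eval" key is missing.
try:
    ToyTPUEstimator().train(hooks=[toy_in_memory_eval_hook])
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'eval'

# 2) An earlier evaluate() call leaves the "eval" key behind.
warmed = ToyTPUEstimator()
warmed.evaluate()
warmed.train(hooks=[toy_in_memory_eval_hook])  # no KeyError in this toy model
print(sorted(warmed._rendezvous))  # ['eval', 'train']
```

In this toy model, calling evaluate() once before train(hooks=...) avoids the KeyError, which would be consistent with the workaround in the edit; it does not model the later "TPU system has not been initialized" error.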

Is there a better way to get per-epoch validation statistics with TPUs? If not, how should validation be done?

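Edit: I seem to have gotten past this error by first running estimator.train() and estimator.evaluate() without the hook, and then running the full training with the hook. Unfortunately, when training resumes after the first evaluation, there is an error: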

Traceback (most recent call last):
  File "dev/google_communicator/worker.py", line 160, in <module>
    main()
  File "dev/google_communicator/worker.py", line 133, in main
    results = evaluator.eval(inputs, outputs)
  File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 329, in eval
    train_hook,
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1471, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 671, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1156, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1255, in run
    raise six.reraise(*original_exc_info)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1240, in run
    return self._sess.run(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1312, in run
    run_metadata=run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
    return self._sess.run(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: The TPU system has not been initialized.
         [[{{node TPUReplicate/_compile/_14248540389241865347/_28}} = TPUCompile[NumDynamicShapes=0, Tguaranteed_constants=[], function=cluster_18378946049549366873_f15n_0[], metadata="\n\006\010...6\323\352L", num_computations=1, _device="/job:worker/replica:0/task:0/device:CPU:0"](^cluster/control_before/_0)]]
         [[{{node tpu_compile_succeeded_assert/_1897752282630996029/_29_G679}} = _Recv[client_terminated=false, recv_device="/job:worker/replica:0/task:0/device:TPU:2", send_device="/job:worker/replica:0/task:0/device:CPU:0", send_device_incarnation=2337451129362726278, tensor_name="edge_174_tpu_compile_succeeded_assert/_1897752282630996029/_29", tensor_type=DT_FLOAT, _device="/job:worker/replica:0/task:0/device:TPU:2"]()]]

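To clarify, this is what happens before the error is raised: the two initialization calls to train() and evaluate() on the estimator run, training runs for one epoch, and the validation set is evaluated. The exception is raised when the estimator tries to resume training.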

This open issue may be related: https://github.com/tensorflow/tensor2tensor/issues/1202
