Unable to train a DeepLab model on TPU

Time: 2019-01-11 11:55:36

Tags: tensorflow deeplab tpu

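I tried to work through the README but ran into all sorts of problems. The program finds the initial checkpoint file, but it never seems to start training. Following the README, I used the Pascal_voc dataset (converted to tfrecord files with the provided script), and I used a model pre-trained on ImageNet as the initial checkpoint. I also tried removing init_checkpoint to train from scratch, but got the same error. Could someone please take a look at this error and give me some advice?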
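The README's conversion step writes the Pascal VOC splits as sharded TFRecord files. As a rough sanity check of `${BUCKET}/DATA` before launching a TPU job, the sketch below verifies that a listing of file names contains a complete set of shards for one split. It assumes the naming pattern `<split>-<shard>-of-<num_shards>.tfrecord` used by DeepLab's `build_voc2012_data.py`; the helper `missing_shards` is hypothetical, and in practice the listing would come from something like `gsutil ls`.

```python
import re

# Assumed naming pattern (from DeepLab's build_voc2012_data.py):
#   <split>-<shard_id:05d>-of-<num_shards:05d>.tfrecord
_SHARD_RE = re.compile(
    r"^(?P<split>\w+)-(?P<idx>\d{5})-of-(?P<total>\d{5})\.tfrecord$")


def missing_shards(filenames, split):
    """Return the sorted shard indices missing for `split` (hypothetical helper).

    `filenames` are base names without a directory. Raises ValueError if no
    shard of `split` is present or if its shards disagree on the total count.
    """
    found = set()
    totals = set()
    for name in filenames:
        m = _SHARD_RE.match(name)
        if m is None or m.group("split") != split:
            continue  # unrelated file or a different split
        found.add(int(m.group("idx")))
        totals.add(int(m.group("total")))
    if not totals:
        raise ValueError("no shards found for split %r" % split)
    if len(totals) > 1:
        raise ValueError("inconsistent shard counts: %s" % sorted(totals))
    total = totals.pop()
    return sorted(set(range(total)) - found)


if __name__ == "__main__":
    listing = [
        "train-00000-of-00004.tfrecord",
        "train-00001-of-00004.tfrecord",
        "train-00003-of-00004.tfrecord",
        "val-00000-of-00004.tfrecord",
    ]
    print(missing_shards(listing, "train"))  # [2]
```

An empty result means every shard index from 0 to `num_shards - 1` is present for that split; it says nothing about whether the records inside are valid.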

> python main.py \
> --mode=eval \
> --num_shards=8 \
> --alsologtostderr=true \
> --model_dir=${BUCKET}/CKPT/model.ckpt \
> --dataset_dir=${BUCKET}/DATA \
> --model_variant=resnet_v1_101_beta \
> --image_pyramid=1. \
> --aspp_with_separable_conv=false \
> --multi_grid=1 \
> --multi_grid=2 \
> --multi_grid=4 \
> --decoder_use_separable_conv=false
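The command passes `--multi_grid` three times. In DeepLab's `common.py` this flag is declared with absl's `flags.DEFINE_multi_integer` (and `image_pyramid` with `DEFINE_multi_float`), so repeated occurrences accumulate into a list such as `[1, 2, 4]` instead of the last value winning. Assuming the TPU `main.py` declares them the same way, the standard-library sketch below mimics that accumulation with `argparse` and `action="append"`. The `build_parser` function is hypothetical and models only a subset of the flags; it is not the script's actual parser.

```python
import argparse


def build_parser():
    """Hypothetical stand-in for a subset of main.py's flags (argparse, not absl)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode")
    parser.add_argument("--num_shards", type=int)
    parser.add_argument("--model_variant")
    # Multi-valued flags: each occurrence appends to a list, similar to
    # absl's DEFINE_multi_integer / DEFINE_multi_float.
    parser.add_argument("--multi_grid", type=int, action="append")
    parser.add_argument("--image_pyramid", type=float, action="append")
    return parser


if __name__ == "__main__":
    argv = [
        "--mode=eval",
        "--num_shards=8",
        "--model_variant=resnet_v1_101_beta",
        "--image_pyramid=1.",
        "--multi_grid=1",
        "--multi_grid=2",
        "--multi_grid=4",
    ]
    args = build_parser().parse_args(argv)
    print(args.multi_grid)     # [1, 2, 4]
    print(args.image_pyramid)  # [1.0]
```

With `action="append"` and no default, an absent flag yields `None`, which lines up with the `None` default these flags have in DeepLab's definitions.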

I0111 13:36:00.646313 139626179581696 tf_logging.py:115] Found an init checkpoint.
I0111 13:36:05.113049 139626179581696 tf_logging.py:115] Create CheckpointSaverHook.
I0111 13:36:05.701862 139626179581696 tf_logging.py:115] Done calling model_fn.
I0111 13:36:12.406670 139626179581696 tf_logging.py:115] TPU job name worker
I0111 13:36:14.939769 139626179581696 tf_logging.py:115] Graph was finalized.
I0111 13:36:16.907617 139626179581696 tf_logging.py:115] Error recorded from training_loop: Input 1 of node CrossReplicaSum was passed int32 from CrossReplicaSum/group_assignment:0 incompatible with expected INVALID.
I0111 13:36:16.907968 139626179581696 tf_logging.py:115] training_loop marked as finished
W0111 13:36:16.908082 139626179581696 tf_logging.py:120] Reraising captured error
Traceback (most recent call last):
  File "main.py", line 304, in <module>
    app.run(main)
  File "/home/al/.local/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/al/.local/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 259, in main
    max_steps=FLAGS.train_steps)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 643, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1107, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1112, in _create_session
    return self._sess_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 800, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 566, in create_session
    init_fn=self._scaffold.init_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 294, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 1 of node CrossReplicaSum was passed int32 from CrossReplicaSum/group_assignment:0 incompatible with expected INVALID.

0 answers:

No answers yet.