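I tried to follow the README but ran into various problems. The program finds the initial checkpoint file but does not seem to be able to start training. As described in the README, I used the Pascal_voc dataset (converted to tfrecord files with the provided script) and a model pre-trained on ImageNet as the initial checkpoint. I also tried removing init_checkpoint to train from scratch, but got the same error. Could someone please take a look at this error and give me some advice?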
> python main.py \
> --mode=eval \
> --num_shards=8 \
> --alsologtostderr=true \
> --model_dir=${BUCKET}/CKPT/model.ckpt \
> --dataset_dir=${BUCKET}/DATA \
> --model_variant=resnet_v1_101_beta \
> --image_pyramid=1. \
> --aspp_with_separable_conv=false \
> --multi_grid=1 \
> --multi_grid=2 \
> --multi_grid=4 \
> --decoder_use_separable_conv=false
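
For readers unfamiliar with the repeated `--multi_grid` flag in the command above, here is a minimal sketch of how repeated occurrences of one flag can accumulate into a list. It assumes the script treats `--multi_grid` as a multi-value flag, and it uses the standard-library `argparse` as a stand-in rather than the absl flags that `main.py` appears to use (per the traceback), so the parser below is illustrative and not the script's actual flag definition.

```python
import argparse

# Hypothetical stand-in parser; the real script defines its own flags.
parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="train")
# action="append" collects every occurrence of the flag into one list.
parser.add_argument("--multi_grid", type=int, action="append")

# Only the flags relevant to this sketch; unknown flags are ignored.
args, unknown = parser.parse_known_args([
    "--mode=eval",
    "--multi_grid=1",
    "--multi_grid=2",
    "--multi_grid=4",
    "--decoder_use_separable_conv=false",
])

print(args.mode, args.multi_grid)
```

Under that assumption, the three `--multi_grid` occurrences are read as one list of rates rather than the last value overriding the earlier ones.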
I0111 13:36:00.646313 139626179581696 tf_logging.py:115] Found an init checkpoint.
I0111 13:36:05.113049 139626179581696 tf_logging.py:115] Create CheckpointSaverHook.
I0111 13:36:05.701862 139626179581696 tf_logging.py:115] Done calling model_fn.
I0111 13:36:12.406670 139626179581696 tf_logging.py:115] TPU job name worker
I0111 13:36:14.939769 139626179581696 tf_logging.py:115] Graph was finalized.
I0111 13:36:16.907617 139626179581696 tf_logging.py:115] Error recorded from training_loop: Input 1 of node CrossReplicaSum was passed int32 from CrossReplicaSum/group_assignment:0 incompatible with expected INVALID.
I0111 13:36:16.907968 139626179581696 tf_logging.py:115] training_loop marked as finished
W0111 13:36:16.908082 139626179581696 tf_logging.py:120] Reraising captured error
Traceback (most recent call last):
  File "main.py", line 304, in <module>
    app.run(main)
  File "/home/al/.local/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/al/.local/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 259, in main
    max_steps=FLAGS.train_steps)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 643, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1107, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1112, in _create_session
    return self._sess_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 800, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 566, in create_session
    init_fn=self._scaffold.init_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 294, in prepare_session
    sess.run(init_op, feed_dict=init_feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1328, in _do_run
    run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1348, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 1 of node CrossReplicaSum was passed int32 from CrossReplicaSum/group_assignment:0 incompatible with expected INVALID.
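
In case it helps anyone reading a longer log of this kind, here is a minimal sketch for pulling the first recorded error out of the log lines, so the failure can be found without scrolling through the traceback. The regular expression is inferred from the log lines shown above (severity letter, MMDD date, time, thread id, `file:line]`, message), and the helper names `parse_log_line` and `first_recorded_error` are hypothetical, not part of TensorFlow or absl.

```python
import re

# Layout inferred from the log above:
# <severity><MMDD> <HH:MM:SS.micros> <thread id> <file:line>] <message>
LOG_LINE = re.compile(
    r"^(?P<severity>[IWEF])(?P<date>\d{4}) "
    r"(?P<time>\d{2}:\d{2}:\d{2}\.\d+) "
    r"(?P<thread>\d+) (?P<source>[^\]]+)\] (?P<message>.*)$"
)


def parse_log_line(line):
    """Return the fields of one log line as a dict, or None if it is not a log line."""
    match = LOG_LINE.match(line.strip())
    return match.groupdict() if match else None


def first_recorded_error(lines):
    """Return the first parsed line whose message starts with 'Error recorded from'."""
    for line in lines:
        record = parse_log_line(line)
        if record and record["message"].startswith("Error recorded from"):
            return record
    return None


# Two lines copied from the log above, plus one non-log line.
sample = [
    "I0111 13:36:14.939769 139626179581696 tf_logging.py:115] Graph was finalized.",
    "I0111 13:36:16.907617 139626179581696 tf_logging.py:115] Error recorded from "
    "training_loop: Input 1 of node CrossReplicaSum was passed int32 from "
    "CrossReplicaSum/group_assignment:0 incompatible with expected INVALID.",
    "Traceback (most recent call last):",
]

record = first_recorded_error(sample)
print(record["time"], record["message"])
```

Applied to the log above, this picks out the `training_loop` line, which carries the same message as the final `InvalidArgumentError`.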