I am training an image classifier with the tf-slim library. Training appears to run fine until, at some point, I get a dimension error:
Traceback (most recent call last):
File "scripts/eval.py", line 190, in <module>
tf.app.run()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "scripts/eval.py", line 186, in main
timeout=None)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/evaluation.py", line 296, in evaluation_loop
timeout=timeout)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/contrib/training/python/training/evaluation.py", line 455, in evaluate_repeatedly
'%Y-%m-%d-%H:%M:%S', time.gmtime()))
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 521, in __exit__
self._close_internal(exception_type)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 556, in _close_internal
self._sess.close()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 791, in close
self._sess.close()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 888, in close
ignore_live_threads=True)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/six.py", line 686, in reraise
raise value
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/queue_runner_impl.py", line 238, in _run
enqueue_callable()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1063, in _single_operation_run
target_list_as_strings, status, None)
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be 4-dimensional[1,1,1000,823,3]
[[Node: eval_image/ResizeArea = ResizeArea[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](eval_image/ExpandDims/_3123, eval_image/ResizeArea/size)]]
[[Node: eval_image/ResizeArea/_3125 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_85_eval_image/ResizeArea", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
How can the training input change its rank?
I use the DataProvider API to fetch the input:
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=FLAGS.num_readers,
    common_queue_capacity=20 * FLAGS.batch_size,
    common_queue_min=10 * FLAGS.batch_size)
[image, label] = provider.get(['image', 'label'])
Then I preprocess it:
image = preprocess_for_train(image, train_image_size, train_image_size, None)
Finally, I run the training:
slim.learning.train(
    train_tensor,
    logdir=FLAGS.train_dir,
    master=FLAGS.master,
    is_chief=(FLAGS.task == 0),
    init_fn=_get_init_fn(),
    summary_op=summary_op,
    number_of_steps=FLAGS.max_number_of_steps,
    log_every_n_steps=FLAGS.log_every_n_steps,
    save_summaries_secs=FLAGS.save_summaries_secs,
    save_interval_secs=FLAGS.save_interval_secs,
    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Here is the code that seems to raise the error:
distorted_image = tf.expand_dims(distorted_image, 0)
distorted_image = tf.image.resize_area(distorted_image, [height, width],
                                       align_corners=False)
distorted_image = tf.squeeze(distorted_image, [0])
Answer 0 (score: 0)
distorted_image = tf.expand_dims(distorted_image, 0)
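This line adds an extra dimension to your tensor. Going by the shape in the error message, it turns a tensor of shape [1, 1000, 823, 3] into one of shape [1, 1, 1000, 823, 3].
For details on how tf.expand_dims works, see its documentation.
The next line, tf.image.resize_area, then throws the error because the tensor now has rank 5, while tf.image.resize_area only accepts rank-4 input ([batch, height, width, channels]).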
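To make the shape bookkeeping concrete, here is a minimal sketch in plain Python that works on shape lists only and does not call TensorFlow. The helper names (expand_dims_shape, check_resize_input, batched_shape) are hypothetical and exist only for this illustration. The rank check merely imitates the error above and is not TensorFlow's actual implementation.

```python
def expand_dims_shape(shape, axis):
    """Return the shape obtained by inserting a size-1 dimension at `axis`.

    Only non-negative axes are handled in this sketch.
    """
    shape = list(shape)
    if not 0 <= axis <= len(shape):
        raise ValueError("axis %d out of range for rank %d" % (axis, len(shape)))
    return shape[:axis] + [1] + shape[axis:]


def check_resize_input(shape):
    """Imitate the rank check: expects [batch, height, width, channels]."""
    if len(shape) != 4:
        raise ValueError("input must be 4-dimensional" + str(list(shape)))
    return list(shape)


def batched_shape(shape):
    """Add a batch dimension only when given a single [height, width, channels] image."""
    if len(shape) == 3:
        return expand_dims_shape(shape, 0)
    return list(shape)


# An image that already carries a leading dimension becomes rank 5 ...
rank5 = expand_dims_shape([1, 1000, 823, 3], 0)
try:
    check_resize_input(rank5)
except ValueError as err:
    print(err)

# ... whereas expanding only rank-3 images keeps the input at rank 4.
print(check_resize_input(batched_shape([1000, 823, 3])))
print(check_resize_input(batched_shape([1, 1000, 823, 3])))
```

In the real graph, the equivalent guard would be to check the static rank of the image tensor before calling tf.expand_dims, or to find out why some decoded images arrive with an extra leading dimension in the first place.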