How can the tensor rank be dynamic in TensorFlow?

Posted: 2017-06-26 04:29:32

Tags: python tensorflow

I'm training an image classifier with the tf-slim library. Training seems to go fine until, at some point, I get a dimension error:

Traceback (most recent call last):
File "scripts/eval.py", line 190, in <module>
    tf.app.run()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "scripts/eval.py", line 186, in main
    timeout=None)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/evaluation.py", line 296, in evaluation_loop
    timeout=timeout)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/contrib/training/python/training/evaluation.py", line 455, in evaluate_repeatedly
    '%Y-%m-%d-%H:%M:%S', time.gmtime()))
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 521, in __exit__
    self._close_internal(exception_type)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 556, in _close_internal
    self._sess.close()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 791, in close
    self._sess.close()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 888, in close
    ignore_live_threads=True)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/six.py", line 686, in reraise
    raise value
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/training/queue_runner_impl.py", line 238, in _run
    enqueue_callable()
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1063, in _single_operation_run
    target_list_as_strings, status, None)
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
File "/home/tmattio/Envs/tf1.2rc2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be 4-dimensional[1,1,1000,823,3]
        [[Node: eval_image/ResizeArea = ResizeArea[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](eval_image/ExpandDims/_3123, eval_image/ResizeArea/size)]]
        [[Node: eval_image/ResizeArea/_3125 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_85_eval_image/ResizeArea", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

How can the training input change its rank?

I use the DataProvider API to get the inputs:

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=FLAGS.num_readers,
    common_queue_capacity=20 * FLAGS.batch_size,
    common_queue_min=10 * FLAGS.batch_size)
[image, label] = provider.get(['image', 'label'])

Then I preprocess the image:

image = preprocess_for_train(image, train_image_size, train_image_size, None)

And finally I run the training:

slim.learning.train(
    train_tensor,
    logdir=FLAGS.train_dir,
    master=FLAGS.master,
    is_chief=(FLAGS.task == 0),
    init_fn=_get_init_fn(),
    summary_op=summary_op,
    number_of_steps=FLAGS.max_number_of_steps,
    log_every_n_steps=FLAGS.log_every_n_steps,
    save_summaries_secs=FLAGS.save_summaries_secs,
    save_interval_secs=FLAGS.save_interval_secs,
    sync_optimizer=optimizer if FLAGS.sync_replicas else None)

Here is the code that seems to raise the error:

distorted_image = tf.expand_dims(distorted_image, 0)
distorted_image = tf.image.resize_area(distorted_image, [height, width],
                                       align_corners=False)
distorted_image = tf.squeeze(distorted_image, [0])
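This snippet assumes that distorted_image arrives as a single rank-3 [height, width, channels] image: it temporarily adds a batch axis of size 1, resizes, and removes that axis again. The sketch below is a shape-only model in plain Python that traces those three steps on a tuple of dimensions. It assumes, as the error message indicates, that the resize op rejects any input that is not 4-D. The helpers expand_dims_shape, resize_area_shape, squeeze_shape and trace_resize are hypothetical stand-ins, not TensorFlow functions, and 224 is only an example target size.

```python
# Shape-only model (hypothetical helpers, no TensorFlow calls) of the
# expand_dims -> resize_area -> squeeze snippet from the question.

def expand_dims_shape(shape, axis):
    """Insert a size-1 axis at `axis`, as tf.expand_dims does to a tensor's shape."""
    return shape[:axis] + (1,) + shape[axis:]


def resize_area_shape(shape, size):
    """Resize [batch, h, w, c] to [batch, new_h, new_w, c]; reject any other rank."""
    if len(shape) != 4:
        raise ValueError("input must be 4-dimensional, got shape {}".format(shape))
    batch, _, _, channels = shape
    new_height, new_width = size
    return (batch, new_height, new_width, channels)


def squeeze_shape(shape, axes):
    """Drop the listed axes, which must have size 1, as tf.squeeze(x, axes) does."""
    for axis in axes:
        if shape[axis] != 1:
            raise ValueError("cannot squeeze axis {} of size {}".format(axis, shape[axis]))
    return tuple(dim for i, dim in enumerate(shape) if i not in axes)


def trace_resize(shape, height, width):
    """Apply the three steps of the question's snippet to a shape tuple."""
    batched = expand_dims_shape(shape, 0)
    resized = resize_area_shape(batched, (height, width))
    return squeeze_shape(resized, [0])


# A single [height, width, channels] image passes through as intended.
print(trace_resize((1000, 823, 3), 224, 224))  # (224, 224, 3)
```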

1 Answer:

Answer 0 (score: 0)

distorted_image = tf.expand_dims(distorted_image, 0) 

will add an extra rank to your tensor, converting it:

  • from [1,1000,823,3]: your original tensor, 1 image of 1000x823 pixels with 3 color channels

  • to [1,1,1000,823,3]: the same data with 1 extra rank.
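To make the shape change concrete, here is a plain-Python model of what tf.expand_dims(x, 0) does to a shape. The helper expand_dims_shape is hypothetical and only manipulates tuples. It does not call TensorFlow.

```python
def expand_dims_shape(shape, axis):
    """Insert a size-1 axis at `axis`, as tf.expand_dims does to a tensor's shape."""
    return shape[:axis] + (1,) + shape[axis:]


# Shape the resize code was written for: one [height, width, channels] image.
single_image = (1000, 823, 3)
# Shape implied by the error message: the image already carries a batch axis
# of size 1 before expand_dims runs.
batched_image = (1, 1000, 823, 3)

print(expand_dims_shape(single_image, 0))   # (1, 1000, 823, 3): rank 4
print(expand_dims_shape(batched_image, 0))  # (1, 1, 1000, 823, 3): rank 5
```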

For details on how tf.expand_dims works, see the tf.expand_dims API documentation.

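The next line, tf.image.resize_area, then throws the error: the tensor now has rank 5, but tf.image.resize_area only accepts 4-D [batch, height, width, channels] input. In other words, distorted_image already had a batch dimension of size 1 before the expand_dims call.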
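Here is a minimal sketch of one possible fix. It assumes the image reaching this code may arrive either as a single [height, width, channels] tensor or already batched as [1, height, width, channels]. The fix adds the batch axis only when the rank is 3. Otherwise it skips both expand_dims and squeeze, so the output keeps the input's rank.

In the TensorFlow graph, the branch could test the static rank distorted_image.get_shape().ndims, assuming that rank is known at graph-construction time (it is None when the rank is unknown). The model below works on plain shape tuples. The name resize_shape_rank_aware is hypothetical, and 224 is only an example target size.

```python
def resize_shape_rank_aware(shape, height, width):
    """Output shape of a resize that accepts either [h, w, c] or [n, h, w, c] shapes.

    Rank 3: models expand_dims(x, 0) -> resize to [1, height, width, c] -> squeeze([0]).
    Rank 4: the batch axis already exists, so the resize is applied directly.
    Any other rank is rejected, since the 4-D resize op could not accept it.
    """
    rank = len(shape)
    if rank == 3:
        channels = shape[2]
        return (height, width, channels)
    if rank == 4:
        batch, channels = shape[0], shape[3]
        return (batch, height, width, channels)
    raise ValueError("expected a rank-3 or rank-4 image shape, got rank {}".format(rank))


print(resize_shape_rank_aware((1000, 823, 3), 224, 224))     # (224, 224, 3)
print(resize_shape_rank_aware((1, 1000, 823, 3), 224, 224))  # (1, 224, 224, 3)
```

An alternative is to remove the extra batch axis upstream, where it was first introduced, rather than branching here. Which option fits better depends on what the rest of the pipeline expects.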