tf.random_crop doesn't work with a wildcard?

Time: 2018-06-07 04:58:02

Tags: python tensorflow

I'm building an image translation network with Python 3.5 + TensorFlow 1.8. For data augmentation, I tried to use tf.random_crop() with a wildcard as shown below:

import tensorflow as tf

# input images (3 channels each)
A = tf.placeholder(tf.float32, shape=(None, 480, 640, 3))
B = tf.placeholder(tf.float32, shape=(None, 480, 640, 3))

# concatenate the images so both are cropped with the same random offset
AB = tf.concat([A, B], 3)

# random cropping with -1 as a wildcard for the batch size
# (the crop depth is 6 because AB has 3 + 3 channels)
AB_cropped = tf.random_crop(AB, [-1, 480, 480, 6])

# split the cropped tensor back into the two images
A_ = AB_cropped[:,:,:,:3]
B_ = AB_cropped[:,:,:,3:]

...

It fails differently from run to run (and sometimes it runs but gives wrong results). The errors look like this:

InvalidArgumentError (see above for traceback): Expected begin[0] in [0, 1], but got 2

Traceback (most recent call last):
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1350, in _do_call
    return fn(*args)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
    status, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin[0] in [0, 1], but got 2
     [[Node: preprocess/random_crop = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](preprocess/concat, preprocess/random_crop/mod, preprocess/random_crop/size)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "alpha_gan_based.py", line 247, in <module>
    _, eg_loss = sess.run([hybrid_op, hybrid_loss], {image:image_batch, depth:depth_batch, z_prior:sample_z()})
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1344, in _do_run
    options, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin[0] in [0, 1], but got 2
     [[Node: preprocess/random_crop = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](preprocess/concat, preprocess/random_crop/mod, preprocess/random_crop/size)]]

Caused by op 'preprocess/random_crop', defined at:
  File "alpha_gan_based.py", line 156, in <module>
    cropped = (tf.random_crop(merged, [-1, 192, 192, 4]) / 255) * 2 - 1
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/random_ops.py", line 316, in random_crop
    return array_ops.slice(value, offset, size, name=name)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py", line 625, in slice
    return gen_array_ops._slice(input_, begin, size, name=name)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 4687, in _slice
    "Slice", input=input, begin=begin, size=size, name=name)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Expected begin[0] in [0, 1], but got 2
     [[Node: preprocess/random_crop = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](preprocess/concat, preprocess/random_crop/mod, preprocess/random_crop/size)]]
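
One way to read this error, assuming the TF 1.x implementation of tf.random_crop (it computes limit = shape - size + 1, draws offset = random % limit, and then calls tf.slice(value, offset, size)): a size of -1 for the batch dimension makes the limit batch_size + 2, so the batch offset can land outside the range Slice accepts. The pure-Python sketch below mimics only that arithmetic for the batch dimension; the helper names are hypothetical and TensorFlow is not used.

```python
# Pure-Python sketch: mimics, for the batch dimension only, the offset
# arithmetic of TF 1.x tf.random_crop and the begin check of tf.slice.
# Helper names are hypothetical; TensorFlow is not used.


def possible_batch_offsets(batch_size, crop_batch):
    """Every batch offset random_crop could draw for the given crop size."""
    # random_crop computes limit = shape - size + 1 and offset = random % limit,
    # so the offset is one of range(limit).
    limit = batch_size - crop_batch + 1
    return list(range(limit))


def images_after_slice(batch_size, begin, size):
    """Number of images tf.slice keeps along the batch dimension."""
    if not 0 <= begin <= batch_size:
        raise ValueError(
            "Expected begin[0] in [0, %d], but got %d" % (batch_size, begin))
    if size == -1:
        # For tf.slice, -1 means "from begin to the end of the dimension".
        size = batch_size - begin
    return size


# A batch of one image, cropped with -1 as the batch "wildcard".
for offset in possible_batch_offsets(batch_size=1, crop_batch=-1):
    try:
        print(offset, "->", images_after_slice(1, offset, -1), "image(s)")
    except ValueError as err:
        print(offset, "->", err)
```

Under this reading, offset 0 keeps the image, offset 1 silently produces an empty batch that only fails later in the network, and offset 2 raises the message above, which would explain why the failure changes from run to run.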

Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero

Traceback (most recent call last):
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1350, in _do_call
    return fn(*args)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
    status, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
     [[Node: encoder/flatten/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](encoder/Relu_6, encoder/flatten/Reshape/shape)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "alpha_gan_based.py", line 248, in <module>
    _, eg_loss = sess.run([hybrid_op, hybrid_loss], {image:image_batch, depth:depth_batch, z_prior:sample_z()})
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1344, in _do_run
    options, run_metadata)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
     [[Node: encoder/flatten/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](encoder/Relu_6, encoder/flatten/Reshape/shape)]]

Caused by op 'encoder/flatten/Reshape', defined at:
  File "alpha_gan_based.py", line 163, in <module>
    z_encoded, intermidiate = encoder(x_real_image)
  File "alpha_gan_based.py", line 54, in encoder
    x = tf.layers.flatten(x)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 414, in flatten
    return layer.apply(inputs)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 762, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 652, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 376, in call
    outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3997, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/usr/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
     [[Node: encoder/flatten/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](encoder/Relu_6, encoder/flatten/Reshape/shape)]]

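If I specify batch_size instead of the wildcard, there is no error. This may not be critical for me, because image translation networks are usually used with batch_size=1. However, I'm concerned about the lack of flexibility in batch_size.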

Is this an unavoidable problem? Or is there another way to specify a wildcard?

Note: Some posts discuss the same error in situations unrelated to tf.random_crop. They say "it's a lack-of-GPU problem!", which for me would be an unavoidable problem... :(

1 Answer

Answer 0 (score: 1)

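random_crop does not support -1 as a wildcard: it treats every entry of size literally as a crop size. Instead of using -1 for the unknown dimension, you can do the following: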

AB_cropped = tf.random_crop(AB, [tf.shape(AB)[0], 480, 480, 6])

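Note that tf.shape returns the dynamic shape of a tensor, whose values are known at run time, so the crop size follows whatever batch size is fed.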
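
A pure-Python sketch of why this works for any batch size, assuming TF 1.x tf.random_crop draws offset = random % (shape - size + 1) for each dimension: when size[0] equals the actual batch size, the batch limit is 1, so the batch offset is always 0 and no image is dropped. The helper possible_offsets and the 6-channel depth are hypothetical, and TensorFlow is not used.

```python
# Pure-Python sketch of TF 1.x tf.random_crop's offset arithmetic
# (limit = shape - size + 1, offset = random % limit). The helper name and
# the channel count are hypothetical; TensorFlow is not used.


def possible_offsets(shape, size):
    """Per-dimension list of the offsets random_crop could draw."""
    return [list(range(dim - s + 1)) for dim, s in zip(shape, size)]


channels = 6  # hypothetical: two 3-channel images concatenated
for batch_size in (1, 4, 16):
    shape = (batch_size, 480, 640, channels)
    # size[0] is the run-time batch size, which tf.shape(AB)[0] supplies
    size = (batch_size, 480, 480, channels)
    counts = [len(offsets) for offsets in possible_offsets(shape, size)]
    print(batch_size, counts)
```

For every batch size, only the width has more than one possible offset (161 choices when cropping 640 to 480). Because random_crop draws a single offset per call, all images in a batch share the same crop window; that keeps A and B aligned, but it is not an independent crop per example.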