Encoding a uint16 PNG (depth_data) with tfrecord

Time: 2018-10-15 07:19:16

Tags: tensorflow

To encode a uint16 PNG with tfrecord, I implemented the code below. However, the decoded output is uint8. How should I modify it? Thanks.

import six
import tensorflow as tf


def _bytes_list_feature(values):
  """Returns a TF-Feature of bytes.

  Args:
    values: A string.

  Returns:
    A TF-Feature.
  """
  def norm2bytes(value):
    # On Python 3, text strings must be encoded to bytes before they can go
    # into a BytesList; bytes (e.g. raw PNG data) pass through unchanged.
    if isinstance(value, str) and six.PY3:
      return value.encode()
    else:
      return value

  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))


def image_to_tfexample(depth_data):
  """Converts depth to tf example.

  Args:
    depth_data: string of depth data.

  Returns:
    tf example of depth.
  """
  return tf.train.Example(features=tf.train.Features(feature={
    'image/depth/encoded': (
      _bytes_list_feature(depth_data)),
    'image/depth/format': _bytes_list_feature(
      FLAGS.depth_format),
  }))
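The encoding side stores the PNG bytes as they are, so the bit depth of the file is not changed by writing it into a TFRecord. Before looking at the decoder, it can help to confirm that the bytes being written really are a 16-bit PNG. The following is a minimal sketch using only the standard library; `png_bit_depth` and `make_gray16_png` are hypothetical helper names introduced here for illustration, and the tiny generated image merely stands in for real depth data.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_bit_depth(png_bytes):
    """Return the bit depth recorded in the IHDR chunk of a PNG byte string."""
    if png_bytes[:8] != PNG_SIGNATURE:
        raise ValueError("not a PNG file")
    # After the 8-byte signature: 4-byte chunk length, 4-byte chunk type,
    # then IHDR data: width (4), height (4), bit depth (1), colour type (1), ...
    if png_bytes[12:16] != b"IHDR":
        raise ValueError("missing IHDR chunk")
    _width, _height, bit_depth, _color_type = struct.unpack(
        ">IIBB", png_bytes[16:26])
    return bit_depth


def _chunk(kind, data):
    """Build one PNG chunk: length, type, data, CRC over type + data."""
    body = kind + data
    crc = zlib.crc32(body) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + body + struct.pack(">I", crc)


def make_gray16_png(rows):
    """Build a small 16-bit grayscale PNG from a list of rows of ints."""
    height, width = len(rows), len(rows[0])
    # Bit depth 16, colour type 0 (grayscale), default compression/filter,
    # no interlacing.
    ihdr = struct.pack(">IIBBBBB", width, height, 16, 0, 0, 0, 0)
    # Each scanline starts with filter type 0, followed by big-endian samples.
    raw = b"".join(
        b"\x00" + struct.pack(">%dH" % width, *row) for row in rows)
    return (PNG_SIGNATURE + _chunk(b"IHDR", ihdr)
            + _chunk(b"IDAT", zlib.compress(raw)) + _chunk(b"IEND", b""))


demo_png = make_gray16_png([[0, 1000], [40000, 65535]])
print(png_bit_depth(demo_png))
```

If this check reports 16 for the bytes passed as `depth_data`, the uint8 result is coming from the decoding step rather than from the encoding step.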

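The following decoding code specifies how the TF-Examples are decoded. The decoder's dtype is uint8, while the source image's dtype is uint16. How can I encode and decode uint16 images with tfrecord?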

keys_to_features = {
  'image/depth/encoded': tf.FixedLenFeature(
      (), tf.string, default_value=''),
  'image/depth/format': tf.FixedLenFeature(
      (), tf.string, default_value='png'),
}
items_to_handlers = {
  'depth': tfexample_decoder.Image(
      image_key='image/depth/encoded',
      format_key='image/depth/format',
      channels=1),
}

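2 answers:

Answer 0: (score: 1)

I'm not an expert, but I took a quick look at the code. It looks like tfexample_decoder.Image currently supports only uint8. You may need to modify the TF code to achieve what you want.

The image.decode_image op supports both uint8 and uint16 images, but tfexample_decoder.Image does not pass the dtype through to it.

If you simply pass the dtype to decode_image, it may work.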
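As an illustration of decoding with an explicit dtype, here is a minimal sketch that bypasses tfexample_decoder and parses the record directly. It assumes TensorFlow 2.x with eager execution, which is not what the question uses; `make_depth_example` and `parse_depth_example` are hypothetical names, and the 2x2 array stands in for real depth data. It is a round-trip check rather than a replacement for the slim decoder.

```python
import numpy as np
import tensorflow as tf


def make_depth_example(png_bytes, fmt=b"png"):
    """Wrap already-encoded PNG bytes in a tf.train.Example."""
    return tf.train.Example(features=tf.train.Features(feature={
        "image/depth/encoded": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[png_bytes])),
        "image/depth/format": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[fmt])),
    }))


def parse_depth_example(serialized):
    """Parse one serialized Example and decode the PNG as uint16."""
    features = tf.io.parse_single_example(serialized, {
        "image/depth/encoded": tf.io.FixedLenFeature(
            (), tf.string, default_value=""),
        "image/depth/format": tf.io.FixedLenFeature(
            (), tf.string, default_value="png"),
    })
    # Passing dtype=tf.uint16 is the key step; the default is tf.uint8.
    return tf.io.decode_png(
        features["image/depth/encoded"], channels=1, dtype=tf.uint16)


# Stand-in depth map with values that do not fit in 8 bits.
depth = np.array([[0, 1000], [40000, 65535]], dtype=np.uint16).reshape(2, 2, 1)
png_bytes = tf.io.encode_png(depth).numpy()

serialized = make_depth_example(png_bytes).SerializeToString()
decoded = parse_depth_example(serialized)
print(decoded.dtype, decoded.shape)
```

The same `parse_depth_example` function could be mapped over a `tf.data.TFRecordDataset`, under the same version assumption.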


Answer 1: (score: 0)

In tfexample_decoder.py, I made the following changes, and it works.

  1. Pass the dtype to decode_image: image_ops.decode_image(image_buffer, channels=self._channels, dtype=self._dtype)
  2. Replace image = control_flow_ops.case(pred_fn_pairs, default=check_jpeg, exclusive=True) with image = decode_image()

Full code:

def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image tensor.
    image_format: The image format for the image in `image_buffer`. If image
      format is `raw`, all images are expected to be in this format, otherwise
      this op can decode a mix of `jpg` and `png` formats.

  Returns:
    A tensor that represents decoded image of self._shape, or
    (?, ?, self._channels) if self._shape is not specified.
  """

  def decode_image():
    """Decodes an image based on the headers."""
    # Changed: pass dtype through so that uint16 images are supported.
    return image_ops.decode_image(
        image_buffer, channels=self._channels, dtype=self._dtype)

  def decode_jpeg():
    """Decodes a jpeg image with specified '_dct_method'."""
    return image_ops.decode_jpeg(
        image_buffer, channels=self._channels, dct_method=self._dct_method)

  def check_jpeg():
    """Checks if an image is jpeg."""
    # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
    # in order to feed the jpeg-specific parameter 'dct_method'.
    return control_flow_ops.cond(
        image_ops.is_jpeg(image_buffer),
        decode_jpeg,
        decode_image,
        name='cond_jpeg')

  def decode_raw():
    """Decodes a raw image."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  pred_fn_pairs = {
      math_ops.logical_or(
          math_ops.equal(image_format, 'raw'),
          math_ops.equal(image_format, 'RAW')): decode_raw,
  }

  # Changed: the original dispatch is commented out and decode_image is
  # called directly.
  # image = control_flow_ops.case(
  #     pred_fn_pairs, default=check_jpeg, exclusive=True)

  image = decode_image()

  image.set_shape([None, None, self._channels])
  if self._shape is not None:
    image = array_ops.reshape(image, self._shape)

  return image