为了用 TFRecord 编码 uint16 的 PNG 图像,我实现了以下代码。但是解码后得到的是 uint8,应该如何修改?谢谢~
def _bytes_list_feature(values):
    """Returns a TF-Feature of bytes.

    Args:
      values: A single value to store. A `str` is encoded with
        `str.encode()` when running under Python 3; any other value is
        passed through unchanged.

    Returns:
      A `tf.train.Feature` whose `bytes_list` holds that single value.
    """
    def norm2bytes(value):
        # On Python 3 a text `str` is converted to bytes before it is placed
        # in the BytesList; already-encoded data (e.g. the contents of a PNG
        # file) is stored as-is.
        if isinstance(value, str) and six.PY3:
            return value.encode()
        return value

    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))
def image_to_tfexample(depth_data):
    """Wraps depth data and its format tag in a tf.train.Example.

    Args:
      depth_data: The depth payload, stored under 'image/depth/encoded'.
        Presumably the encoded image file contents -- confirm with caller.

    Returns:
      A `tf.train.Example` with two bytes features: the payload under
      'image/depth/encoded' and `FLAGS.depth_format` under
      'image/depth/format'.
    """
    feature_map = {
        'image/depth/encoded': _bytes_list_feature(depth_data),
        'image/depth/format': _bytes_list_feature(FLAGS.depth_format),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
以下解码代码用于指定 TF-Example 的解码方式。解码器输出的 dtype 为 uint8,而源图像的 dtype 为 uint16。如何使用 TFRecord 编码和解码 uint16 图像?
# Parsing spec: both entries are scalar string features with a default value.
keys_to_features = {}
keys_to_features['image/depth/encoded'] = tf.FixedLenFeature(
    (), tf.string, default_value='')
keys_to_features['image/depth/format'] = tf.FixedLenFeature(
    (), tf.string, default_value='png')

# Maps the output item 'depth' to a single-channel image handler that reads
# the two features declared above.
# NOTE(review): no dtype argument is given here, so the handler's default
# output dtype applies (reported as uint8 for this setup -- confirm).
items_to_handlers = dict(
    depth=tfexample_decoder.Image(
        image_key='image/depth/encoded',
        format_key='image/depth/format',
        channels=1),
)
答案 0 :(得分:1)
我不是专家,只是快速浏览了一下代码。看起来 tfexample_decoder.Image 目前仅支持 uint8,您可能需要修改 TF 的代码才能实现目标。
image.decode_image 运算符同时支持 uint8 和 uint16 图像,但是 tfexample_decoder.Image 并没有把 dtype 参数传给它。
如果您把 dtype 传递给 decode_image,应该就可以正常工作。
参考:
答案 1 :(得分:0)
在tfexample_decoder.py中,我进行了如下更改,并且可以正常工作。
完整代码
def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    The buffer is always decoded with `image_ops.decode_image`, forwarding
    `self._dtype` so the output uses the configured dtype rather than that
    op's default.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`. It is
        only used to build `pred_fn_pairs`, which the active code path does
        not consult, so it has no effect on which decoder runs.

    Returns:
      A tensor that represents decoded image of self._shape, or
      (?, ?, self._channels) if self._shape is not specified.
    """

    def decode_image():
        """Decodes a image based on the headers."""
        # Pass the requested dtype through; without it decode_image falls
        # back to its own default dtype.
        return image_ops.decode_image(
            image_buffer, channels=self._channels, dtype=self._dtype)

    def decode_jpeg():
        """Decodes a jpeg image with specified '_dct_method'."""
        return image_ops.decode_jpeg(
            image_buffer, channels=self._channels,
            dct_method=self._dct_method)

    def check_jpeg():
        """Checks if an image is jpeg."""
        # For jpeg, we directly use image_ops.decode_jpeg rather than
        # decode_image in order to feed the jpeg specify parameter
        # 'dct_method'.
        return control_flow_ops.cond(
            image_ops.is_jpeg(image_buffer),
            decode_jpeg,
            decode_image,
            name='cond_jpeg')

    def decode_raw():
        """Decodes a raw image."""
        return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    pred_fn_pairs = {
        math_ops.logical_or(
            math_ops.equal(image_format, 'raw'),
            math_ops.equal(image_format, 'RAW')): decode_raw,
    }

    # NOTE(review): the format dispatch
    #   control_flow_ops.case(pred_fn_pairs, default=check_jpeg,
    #                         exclusive=True)
    # is intentionally bypassed, so `decode_raw`, `check_jpeg`,
    # `decode_jpeg` and `pred_fn_pairs` above are currently unused and raw
    # buffers are no longer handled by this method. Presumably the branches
    # would have to return the same dtype before it can be restored --
    # confirm before re-enabling.
    image = decode_image()

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
        image = array_ops.reshape(image, self._shape)

    return image