TFRecord-将png转换为字节

时间:2018-12-30 20:18:35

标签: tensorflow

创建tfrecord的代码:

def convert(self):
    with tf.python_io.TFRecordWriter(self.tfrecord_out) as writer:
        example = self._convert_image()
        writer.write(example.SerializeToString())

def _convert_image(self):
    for (path, label) in zip(self.image_paths, self.labels):
        label = int(label)
        # Read image data in terms of bytes
        with open(path, 'rb') as fid:
            png_bytes = fid.read()

        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes]))
            }))
    return example

我的问题是,当我从文件中读取图像时,图像无法正确解码:

def parse(self, serialized):
    features = \
        {
            'image': tf.FixedLenFeature([], tf.string)
        }

    parsed_example = tf.parse_single_example(serialized=serialized,
                                                 features=features)

    image_raw = parsed_example['image']
    image = tf.image.decode_png(contents=image_raw, channels=3, dtype=tf.uint8)
    image = tf.cast(image, tf.float32)
    return image`

有人知道为什么吗?

rubbish picture

1 个答案:

答案 0 :(得分:0)

找到了解决方案,希望我的愚蠢错误会帮助其他人。

将张量重整为张量板的4个尺寸[batch_size, height, width, channels]时,我切换了宽度和高度。

正确的重塑代码是:

x_reshaped = session.run(tf.reshape(tensor=decoded_png_uint8, shape=[batch_size, height, width, channels], name="x_reshaped"))

但是我有shape=[batch_size, width, height, channels]。呃,好吧。每天都是上学日。

the correct output