PNG to TFRecord: Error: has type numpy.ndarray, but expected one of: bytes

Posted: 2018-11-19 20:36:45

Tags: python tensorflow png tfrecord

I am converting .png files into a TFRecord.


The file that converts to an Example:

import tensorflow as tf


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
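`tf.train.BytesList` only accepts `bytes` values, so handing it the NumPy array that comes out of `_process_image` produces the error in the title. A minimal sketch of a guard that coerces common inputs to `bytes` before they reach `_bytes_feature`; `to_bytes_value` is a hypothetical helper, not part of TensorFlow, and it assumes raw row-major pixel bytes are what you want to store for arrays:

```python
import numpy as np


def to_bytes_value(value):
    """Hypothetical helper: coerce a value into bytes for a BytesList feature."""
    if isinstance(value, bytes):
        return value  # already the type BytesList expects
    if isinstance(value, str):
        return value.encode("utf-8")  # text such as labels or filenames
    if isinstance(value, np.ndarray):
        return value.tobytes()  # raw pixel buffer in row-major (C) order
    raise TypeError(f"cannot store {type(value).__name__} in a BytesList")


# Usage: coerce before wrapping, e.g. _bytes_feature(to_bytes_value(image_buffer))
pixels = np.zeros((2, 3, 3), dtype=np.uint8)
payload = to_bytes_value(pixels)
print(type(payload).__name__, len(payload))  # bytes 18
```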

Building the Example (I'm not sure about the decode_png part):

def _convert_to_example(filename, image_buffer, label, text, height, width):
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        # ........
        'image/encoded': _bytes_feature(image_buffer)}))  # Error is raised here
    return example

The process_image file:

def decode_png(image_data):

    img_str = image_data.tostring()
    reconstructed_img_1d = np.fromstring(img_str, dtype=np.uint8)
    reconstructed_img = reconstructed_img_1d.reshape(image_data.shape)

    return reconstructed_img
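Note that `decode_png` as written only round-trips the array: it serializes the pixels to a byte string and immediately rebuilds an array of the original shape, so it still returns a `numpy.ndarray` and does not change the type that reaches `image/encoded`. A quick check, using `tobytes`/`np.frombuffer` (the non-deprecated equivalents of `tostring`/`np.fromstring`) on a small stand-in array rather than a real PNG:

```python
import numpy as np


def decode_png(image_data):
    # Same logic as the question's version, with current NumPy names.
    img_bytes = image_data.tobytes()
    reconstructed_img_1d = np.frombuffer(img_bytes, dtype=np.uint8)
    return reconstructed_img_1d.reshape(image_data.shape)


# Stand-in for the HxWxC uint8 array that io.imread returns for an RGB PNG.
image_data = np.arange(60, dtype=np.uint8).reshape(4, 5, 3)
result = decode_png(image_data)

print(type(result).__name__)               # ndarray: still not bytes
print(np.array_equal(result, image_data))  # True: same pixels, same shape
```

The round trip also hard-codes `np.uint8`, so it only holds for 8-bit PNGs; a 16-bit image would yield twice as many elements and the `reshape` would fail.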

The main file, which processes the image files in batches:

def _process_image(filename):

    # Read the image file.
    with open(filename, 'r') as f:
        image_data = io.imread(f)

    # Decode the RGB PNG.
    image = decode_png(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width
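An alternative to storing decoded pixels is to store the PNG file's own bytes, which are already `bytes` and keep the PNG compression; the reading side would then decode them, for example with `tf.io.decode_png`. A stdlib-only sketch under that assumption; `read_png_bytes` is a hypothetical name, and instead of decoding the pixels it takes height, width and color type from the PNG `IHDR` header:

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def read_png_bytes(filename):
    """Hypothetical helper: return (encoded_bytes, height, width, color_type)."""
    with open(filename, "rb") as f:  # binary mode: PNG data is not text
        encoded = f.read()
    if not encoded.startswith(PNG_SIGNATURE):
        raise ValueError(f"{filename} is not a PNG file")
    # The first chunk must be IHDR: a 4-byte length, the tag b"IHDR", then
    # width and height as big-endian uint32, followed by bit depth and color type.
    if encoded[12:16] != b"IHDR":
        raise ValueError("missing IHDR chunk")
    width, height, _bit_depth, color_type = struct.unpack(">IIBB", encoded[16:26])
    return encoded, height, width, color_type
```

In `_process_image` you would then return `encoded, height, width`. A color type of 2 corresponds to the `image.shape[2] == 3` RGB check above (6 would be RGBA).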

However, I get the error from the title in this loop:

for i in files_in_shard:
    filename = filenames[i]
    label = labels[i]
    text = texts[i]
    image_buffer, height, width = _process_image(filename)
    example = _convert_to_example(filename, image_buffer, label, text, height, width)
    writer.write(example.SerializeToString())
    shard_counter += 1
    counter += 1

How should I handle this? Any help would be great. Thanks.

1 Answer:

Answer 0 (score: 0)

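OK, I figured it out... it was a logic problem. `image/encoded` needs bytes, so the function has to return the byte string instead of the array: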

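import numpy as np


def decode_png(image_data):
    # Serialize the pixel array to raw bytes (tobytes replaces the deprecated tostring).
    img_str = image_data.tobytes()
    # Sanity check only: the bytes rebuild an array with the original shape.
    reconstructed_img_1d = np.frombuffer(img_str, dtype=np.uint8)
    reconstructed_img = reconstructed_img_1d.reshape(image_data.shape)

    # Return the bytes, not the array, so that BytesList accepts the value.
    return img_str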


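from skimage import io  # assumption: `io` in the original is skimage.io


def _process_image(filename):
    # Read and decode the PNG; pass the path so the file is not opened in text mode.
    image_data = io.imread(filename)

    # Serialize the pixels to bytes for the BytesList feature.
    img_str = image_data.tobytes()

    height = image_data.shape[0]
    width = image_data.shape[1]

    return img_str, height, width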
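With this fix, `image/encoded` holds raw, uncompressed pixels rather than PNG data, so whoever reads the TFRecord must know the height, width, channel count and dtype to rebuild the image (in TensorFlow that is typically `tf.io.decode_raw(..., tf.uint8)` followed by a reshape, not `tf.io.decode_png`). A NumPy-only sketch of the write/read pair, assuming 8-bit HxWxC images; `encode_raw` and `decode_raw` are hypothetical names:

```python
import numpy as np


def encode_raw(image):
    """Return (raw_bytes, height, width, channels) for an HxWxC uint8 image."""
    if image.dtype != np.uint8 or image.ndim != 3:
        raise ValueError("expected an HxWxC uint8 array")
    height, width, channels = image.shape
    return image.tobytes(), height, width, channels


def decode_raw(raw_bytes, height, width, channels):
    """Inverse of encode_raw: rebuild the array from the stored bytes and shape."""
    flat = np.frombuffer(raw_bytes, dtype=np.uint8)
    return flat.reshape(height, width, channels)


image = np.arange(2 * 4 * 3, dtype=np.uint8).reshape(2, 4, 3)
raw, h, w, c = encode_raw(image)
restored = decode_raw(raw, h, w, c)
print(len(raw), np.array_equal(restored, image))  # 24 True
```

Storing `channels` as another `_int64_feature` avoids hard-coding 3 on the reading side. This matters here because the answer's `_process_image` dropped the `shape[2] == 3` check, so an RGBA or grayscale PNG would now pass through silently.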