使用Dataset API从TFrecord读取图像并将其显示在Jupyter笔记本上

时间:2018-07-17 00:08:42

标签: tensorflow tensorflow-datasets tfrecord

我从图像文件夹创建了一个tfrecord,现在我想使用Dataset API遍历TFrecord文件中的条目,并在Jupyter笔记本上显示它们。但是我在读取tfrecord文件时遇到了问题。

我用来创建TFRecord的代码

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def generate_tfr(image_list):
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for image in images:
            image_bytes = open(image,'rb').read()
            image_array = imread(image)
            image_shape = image_array.shape
            image_x, image_y, image_z = image_shape[0],image_shape[1], image_shape[2]

            data = {

              'image/bytes':_bytes_feature(image_bytes),
              'image/x':_int64_feature(image_x),
              'image/y':_int64_feature(image_y),
              'image/z':_int64_feature(image_z)
            }

            features = tf.train.Features(feature=data)
            example = tf.train.Example(features=features)
            serialized = example.SerializeToString()
            writer.write(serialized)

读取TFRecord的代码

#This code is incomplete and has many flaws. 
#Please give some suggestions in correcting this code if you can

def parse(serialized):
    features = \
    {
        'image/bytes': tf.FixedLenFeature([], tf.string),
        'image/x': tf.FixedLenFeature([], tf.int64),
        'image/y': tf.FixedLenFeature([], tf.int64),
        'image/z': tf.FixedLenFeature([], tf.int64)
    }

    parsed_example = tf.parse_single_example(serialized=serialized,features=features)
    image = parsed_example['image/bytes']
    image = tf.decode_raw(image,tf.uint8)
    x = parsed_example['image/x'] # breadth
    y = parsed_example['image/y'] # height
    z = parsed_example['image/z'] # depth
    image = tf.cast(image,tf.float32)

    # how can I reshape image tensor here? tf.reshape throwing some weird errors.

    return {'image':image,'x':x,'y':y,'z':z}



dataset = tf.data.TFRecordDataset([output_path])
dataset.map(parse)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
epoch = 1

with tf.Session() as sess:
    for _ in range(epoch):
    img = next_element.eval()
    print(img)
    # when I print image, it shows byte code.
    # How can I convert it to numpy array and then show image on my jupyter notebook ?

我以前从未使用过任何这些,而且我一直在阅读TFRecords。请回答如何遍历TFrecord的内容并将其显示在Jupyter笔记本上。随时纠正/优化这两段代码。那对我有很大帮助。

1 个答案:

答案 0 :(得分:1)

这是您想要的吗?我认为一旦将其转换为numpy数组,您就可以使用PIL.Image在jupyter笔记本中显示。

将tf记录转换为numpy => How can I convert TFRecords into numpy arrays?

将numpy数组显示为图像 https://gist.github.com/kylemcdonald/2f1b9a255993bf9b2629