我用一个图像文件夹创建了 TFRecord 文件,现在想使用 Dataset API 遍历 TFRecord 文件中的条目,并在 Jupyter 笔记本中显示它们。但是我在读取 TFRecord 文件时遇到了问题。
我用来创建TFRecord的代码
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
def generate_tfr(image_list):
    """Serialize every image in ``image_list`` into a single TFRecord file.

    Each record stores the file's encoded bytes under 'image/bytes' and the
    dimensions of the decoded array under 'image/x', 'image/y' and 'image/z'.
    The records are written to the module-level ``output_path``.

    Args:
        image_list: Iterable of image file paths.
    """
    with tf.python_io.TFRecordWriter(output_path) as writer:
        # Iterate over the parameter; the original looped over an undefined
        # name `images` and ignored `image_list`.
        for image in image_list:
            # Close the file handle deterministically instead of leaking it.
            with open(image, 'rb') as image_file:
                image_bytes = image_file.read()
            image_shape = imread(image).shape
            image_x, image_y = image_shape[0], image_shape[1]
            # A 2-D (single-channel) array has no third axis; record depth 1
            # rather than failing with IndexError.
            image_z = image_shape[2] if len(image_shape) > 2 else 1
            data = {
                'image/bytes': _bytes_feature(image_bytes),
                'image/x': _int64_feature(image_x),
                'image/y': _int64_feature(image_y),
                'image/z': _int64_feature(image_z),
            }
            features = tf.train.Features(feature=data)
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
读取TFRecord的代码
#This code is incomplete and has many flaws.
#Please give some suggestions in correcting this code if you can
def parse(serialized):
    """Decode one serialized ``tf.train.Example`` into image and size tensors.

    Args:
        serialized: Scalar string tensor holding one serialized example with
            the features 'image/bytes', 'image/x', 'image/y' and 'image/z'.

    Returns:
        Dict with 'image' (float32 tensor reshaped to [x, y, z]) and the
        int64 scalars 'x', 'y' and 'z'.
    """
    features = {
        'image/bytes': tf.FixedLenFeature([], tf.string),
        'image/x': tf.FixedLenFeature([], tf.int64),
        'image/y': tf.FixedLenFeature([], tf.int64),
        'image/z': tf.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.parse_single_example(serialized=serialized,
                                             features=features)
    x = parsed_example['image/x']  # array shape[0] at write time (rows)
    y = parsed_example['image/y']  # array shape[1] at write time (columns)
    z = parsed_example['image/z']  # array shape[2] at write time (channels)
    # 'image/bytes' holds the encoded file contents (JPEG/PNG/...), not raw
    # pixels, so it needs an image decoder. tf.decode_raw would only expose
    # the compressed byte stream, whose length does not match x * y * z.
    image = tf.image.decode_image(parsed_example['image/bytes'])
    # The stored sizes are only known at run time, so pass them as a tensor.
    # NOTE(review): assumes the writer's image reader and decode_image agree
    # on the channel count, and that files are not multi-frame GIFs - confirm.
    image = tf.reshape(image, tf.stack([x, y, z]))
    image = tf.cast(image, tf.float32)
    return {'image': image, 'x': x, 'y': y, 'z': z}
# Input pipeline: read serialized examples from the TFRecord file and decode
# each one with parse().
dataset = tf.data.TFRecordDataset([output_path])
# Dataset transformations return a new dataset instead of modifying the
# receiver, so the result of map() must be kept.
dataset = dataset.map(parse)
epoch = 1
# Go through the whole file `epoch` times.
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    # A one-shot iterator signals the end of the data by raising
    # OutOfRangeError, so keep fetching until that happens.
    while True:
        try:
            # next_element is the dict returned by parse(); a dict has no
            # .eval(), but sess.run() evaluates every tensor in it and hands
            # back the same structure filled with NumPy values.
            img = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        print(img)
我以前从未用过这些工具,最近一直在阅读有关 TFRecords 的资料。请说明如何遍历 TFRecord 的内容并将其显示在 Jupyter 笔记本中。欢迎随时纠正或优化这两段代码,这会对我有很大帮助。
答案 0 :(得分:1)
这是您想要的吗?我认为一旦将其转换为numpy数组,您就可以使用PIL.Image在jupyter笔记本中显示。
将 TFRecord 转换为 numpy 数组 => How can I convert TFRecords into numpy arrays?
将numpy数组显示为图像 https://gist.github.com/kylemcdonald/2f1b9a255993bf9b2629