Reading images from a TFRecords file

Time: 2017-09-08 09:15:55

Tags: python image tensorflow

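I wrote two functions with TensorFlow:

  • convert_imgs_to_TFRecords, which converts all images in ./dataset into a TFRecords file, img.tfrecords

  • read_imgs_from_TFRecords, which reads img.tfrecords and retrieves each image together with its metadata: height, width, channel and name

However, the images do not match their names.

For example, image A (named 001.jpg) and image B (named 002.jpg) are written to img.tfrecords, but after read_imgs_from_TFRecords, image A gets the name 002.jpg and image B gets the name 001.jpg.

The two functions are as follows: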

import os

import cv2
import tensorflow as tf

# _bytes_feature and _int64_feature are helper functions that the question does not show.
def convert_imgs_to_TFRecords(imgs_dir='./dataset', tfrecords_name='img.tfrecords'):

    img_filenames_list = os.listdir(imgs_dir)

    writer = tf.python_io.TFRecordWriter(tfrecords_name)

    for item in img_filenames_list:
        file_extension = item.split('.')[-1]
        if file_extension == 'jpg':
            # Build the path from imgs_dir instead of a hard-coded './dataset'.
            img_filename = os.path.join(imgs_dir, item)
            print("writing {0}".format(item))
            img = cv2.imread(img_filename)  # uint8 array of shape (rows, cols, channels)
            rows = img.shape[0]
            cols = img.shape[1]
            channels = img.shape[2]
            # Store the name, the shape and the raw pixels in one Example.
            example = tf.train.Example(features=tf.train.Features(feature={
                'name': _bytes_feature(item.encode('utf-8')),  # str to bytes
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'channel': _int64_feature(channels),
                'img': _bytes_feature(img.tobytes())  # raw pixel bytes
                }))
            writer.write(example.SerializeToString())

    writer.close()

def read_imgs_from_TFRecords(tfrecords_file='./img.tfrecords'):
    filename_queue = tf.train.string_input_producer(string_tensor=[tfrecords_file], 
                                                num_epochs=None, 
                                                shuffle=False, 
                                                seed=None, 
                                                capacity=32, 
                                                shared_name=None, 
                                                name=None, 
                                                cancel_op=None)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'name': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'channel': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string)
            })    
    image = tf.decode_raw(features['img'], tf.uint8)
    # normalize
    # normalize_op = tf.cast(image, tf.float32) * (1.0/255) - 0.5

    height = features['height']
    width = features['width']
    channel = features['channel']
    name = features['name']
    print("ready to run session")
    init_op = tf.group(tf.local_variables_initializer(), 
                   tf.global_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(22):
            img = image.eval()
            h, w, c = [height.eval(), width.eval(), channel.eval()]
            title = name.eval()
            title = title.decode()  # bytes to str
            img = img.reshape([h, w, c])
            # pil_image = Image.fromarray(img)
            # pil_image.show()
            print('showing ' + title)
            cv2.imwrite(title, img)
        coord.request_stop()
        coord.join(threads)

1 answer:

Answer 0 (score: 1)

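As Gphilo and Jie.Zhou said in the comments, all parts of one example should be fetched in a single sess.run. Each eval() call is a separate run that makes the reader dequeue another record, so tensors evaluated separately come from different records. So I changed

img = image.eval()
h, w, c = [height.eval(), width.eval(), channel.eval()]
title = name.eval()

to

img, h, w, c, title = sess.run([image, height, width, channel, name])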
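To make the cause of the mismatch concrete, here is a toy model in plain Python. It is not TensorFlow code: ToyReader, RECORDS and the two read functions are hypothetical names invented for illustration, and the only assumption carried over is that every run consumes the next record.

```python
from itertools import cycle

# Three fake records standing in for the contents of img.tfrecords.
RECORDS = [
    {"name": "001.jpg", "img": "pixels of A"},
    {"name": "002.jpg", "img": "pixels of B"},
    {"name": "003.jpg", "img": "pixels of C"},
]


class ToyReader:
    """Each call to run() consumes one record, like one sess.run()."""

    def __init__(self, records):
        # cycle() mimics num_epochs=None: the file is read again and again.
        self._records = cycle(records)

    def run(self, fetches):
        record = next(self._records)
        return [record[key] for key in fetches]


def read_separately(reader):
    # Mirrors image.eval() followed by name.eval(): two runs, two records.
    (img,) = reader.run(["img"])
    (name,) = reader.run(["name"])
    return img, name


def read_together(reader):
    # Mirrors sess.run([image, name]): one run, one record.
    img, name = reader.run(["img", "name"])
    return img, name


def collect(read_fn, count):
    reader = ToyReader(RECORDS)
    return [read_fn(reader) for _ in range(count)]


mismatched = collect(read_separately, 3)
matched = collect(read_together, 3)
print(mismatched)
print(matched)
```

In this model the separate reads pair each image with the name of a later record, while the combined read keeps every image with its own name.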

These two functions were only an experiment with TFRecords; in a real project it is better to use the Datasets API.
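As a rough illustration of that suggestion, the sketch below reads the same record layout with tf.data. It assumes TensorFlow 2.x with eager execution, which is newer than the API used in the question. The names parse_example, read_imgs_with_dataset and write_demo_file are hypothetical, and the demo writes small synthetic arrays instead of real JPEG files.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Same feature layout as in the question.
FEATURE_SPEC = {
    'name': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'channel': tf.io.FixedLenFeature([], tf.int64),
    'img': tf.io.FixedLenFeature([], tf.string),
}


def parse_example(serialized):
    """Decode one record; every field comes from the same serialized example."""
    f = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    shape = tf.stack([f['height'], f['width'], f['channel']])
    img = tf.reshape(tf.io.decode_raw(f['img'], tf.uint8), shape)
    return img, f['name']


def read_imgs_with_dataset(tfrecords_file):
    """Return (image array, file name) pairs in the order they were written."""
    dataset = tf.data.TFRecordDataset(tfrecords_file).map(parse_example)
    return [(img.numpy(), name.numpy().decode('utf-8')) for img, name in dataset]


def write_demo_file(path, named_images):
    """Write synthetic images using the record layout from the question."""
    with tf.io.TFRecordWriter(path) as writer:
        for name, img in named_images:
            h, w, c = img.shape
            feature = {
                'name': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[name.encode('utf-8')])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
                'channel': tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
                'img': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img.tobytes()])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


# Demo with two synthetic "images" of different sizes.
demo_images = [
    ('001.jpg', np.full((2, 3, 3), 1, dtype=np.uint8)),
    ('002.jpg', np.full((4, 2, 3), 2, dtype=np.uint8)),
]
demo_path = os.path.join(tempfile.mkdtemp(), 'img.tfrecords')
write_demo_file(demo_path, demo_images)
results = read_imgs_with_dataset(demo_path)
```

Because parse_example receives one serialized record and returns the image and the name together, the two cannot drift apart the way separately evaluated tensors did.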