我使用TensorFlow写了两个方法:
convert_imgs_to_TFRecords
,将./dataset
中的所有图片转换为TFRecords文件img.tfrecords
read_imgs_from_TFRecords
,阅读img.tfrecords
,获取image
及其信息,包括height
,weight
,channel
和name
。
但是这些图像与他们的名字不匹配。
例如,名为001.jpg
的图像和名为002.jpg
的B图像将转换为img.tfrecords
,但图像获取名称002.jpg
,B图像获得001.jpg
在read_imgs_from_TFRecords
之后。
这两种方法如下:
def convert_imgs_to_TFRecords(imgs_dir='./dataset', tfrecords_name='img.tfrecords'):
img_filenames_list = os.listdir(imgs_dir)
writer = tf.python_io.TFRecordWriter(tfrecords_name)
for item in img_filenames_list:
file_extension = item.split('.')[-1]
if(file_extension == 'jpg'):
img_filename = os.path.join('./dataset', item)
print("writing {0}".format(item))
img = cv2.imread(img_filename)# uint8 dtype
rows = img.shape[0]
cols = img.shape[1]
channels = img.shape[2]
example = tf.train.Example(features = tf.train.Features(feature={
'name': _bytes_feature(item.encode('utf-8')), # str to bytes
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'channel': _int64_feature(channels),
'img': _bytes_feature(img.tostring())
}))
writer.write(example.SerializeToString())
writer.close()
和
def read_imgs_from_TFRecords(tfrecords_file='./img.tfrecords'):
filename_queue = tf.train.string_input_producer(string_tensor=[tfrecords_file],
num_epochs=None,
shuffle=False,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
'name': tf.FixedLenFeature([], tf.string),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channel': tf.FixedLenFeature([], tf.int64),
'img': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['img'], tf.uint8)
# normalize
# normalize_op = tf.cast(image, tf.float32) * (1.0/255) - 0.5
height = features['height']
width = features['width']
channel = features['channel']
name = features['name']
print("ready to run session")
init_op = tf.group(tf.local_variables_initializer(),
tf.global_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(22):
img = image.eval()
h, w, c = [height.eval(), width.eval(), channel.eval()]
title = name.eval()
title = title.decode()#bytes to str
img = img.reshape([h, w, c])
# pil_image = Image.fromarray(img)
# pil_image.show()
print('showing ' + title)
cv2.imwrite(title, img)
coord.request_stop()
coord.join(threads)
答案 0 :(得分:1)
正如Gphilo和Jie.Zhou在评论中所说,我们应该把一个例子的所有部分组合成一个单独的sess.run。 所以,我更正
img = image.eval()
h, w, c = [height.eval(), width.eval(), channel.eval()]
title = name.eval()
到
img, h, w, c, title = sess.run([image, height, width, channel, name])
这两种方法只是尝试tf.TFRecord,最好在项目中使用Datasets API。