我有一个自定义数据集,然后将其存储为tfrecord,
# toy example data
label = np.asarray([[1,2,3],
[4,5,6]]).reshape(2, 3, -1)
sample = np.stack((label + 200).reshape(2, 3, -1))
def bytes_feature(values):
"""Returns a TF-Feature of bytes.
Args:
values: A string.
Returns:
A TF-Feature.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def labeled_image_to_tfexample(sample_binary_string, label_binary_string):
return tf.train.Example(features=tf.train.Features(feature={
'sample/image': bytes_feature(sample_binary_string),
'sample/label': bytes_feature(label_binary_string)
}))
def _write_to_tf_record():
with tf.Graph().as_default():
image_placeholder = tf.placeholder(dtype=tf.uint16)
encoded_image = tf.image.encode_png(image_placeholder)
label_placeholder = tf.placeholder(dtype=tf.uint16)
encoded_label = tf.image.encode_png(image_placeholder)
with tf.python_io.TFRecordWriter("./toy.tfrecord") as writer:
with tf.Session() as sess:
feed_dict = {image_placeholder: sample,
label_placeholder: label}
# Encode image and label as binary strings to be written to tf_record
image_string, label_string = sess.run(fetches=(encoded_image, encoded_label),
feed_dict=feed_dict)
# Define structure of what is going to be written
file_structure = labeled_image_to_tfexample(image_string, label_string)
writer.write(file_structure.SerializeToString())
return
但是我看不懂它。首先,我尝试过(基于http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html,https://medium.com/coinmonks/storage-efficient-tfrecord-for-images-6dc322b81db4和https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564)
def read_tfrecord_low_level():
data_path = "./toy.tfrecord"
filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
reader = tf.TFRecordReader()
_, raw_records = reader.read(filename_queue)
decode_protocol = {
'sample/image': tf.FixedLenFeature((), tf.int64),
'sample/label': tf.FixedLenFeature((), tf.int64)
}
enc_example = tf.parse_single_example(raw_records, features=decode_protocol)
recovered_image = enc_example["sample/image"]
recovered_label = enc_example["sample/label"]
return recovered_image, recovered_label
我还尝试了变体方法,例如,将enc_example转换并解码,例如在Unable to read from Tensorflow tfrecord file中。但是,当我尝试评估它们时,我的python会话只会冻结并且不提供任何输出或回溯。
然后我尝试使用热切的执行来查看发生了什么,但是显然它仅与tf.data API兼容。但是据我了解,对tf.data API的转换是在整个数据集上进行的。 https://www.tensorflow.org/api_guides/python/reading_data提到必须编写一个解码函数,但是没有给出如何执行该操作的示例。我发现的所有教程都是针对TFRecordReader制作的(对我不起作用)。
任何帮助(指出我在做什么错/解释正在发生的事情/有关如何使用tf.data API解码tfrecord的指示)都受到高度赞赏。
根据https://www.youtube.com/watch?v=4oNdaQk0Qv4和https://www.youtube.com/watch?v=uIcqeP7MFH0 tf.data是创建输入管道的最佳方法,所以我对学习这种方法非常感兴趣。
谢谢!
答案 0 :(得分:2)
我不确定为什么存储编码的png会导致评估不起作用,但是这是解决此问题的一种可能方法。既然您提到要使用tf.data
创建输入管道的方式,我将在玩具示例中展示如何使用它:
label = np.asarray([[1,2,3],
[4,5,6]]).reshape(2, 3, -1)
sample = np.stack((label + 200).reshape(2, 3, -1))
首先,必须将数据保存到TFRecord文件。与您所做的不同之处在于,该图像未编码为png。
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
writer = tf.python_io.TFRecordWriter("toy.tfrecord")
example = tf.train.Example(features=tf.train.Features(feature={
'label_raw': _bytes_feature(tf.compat.as_bytes(label.tostring())),
'sample_raw': _bytes_feature(tf.compat.as_bytes(sample.tostring()))}))
writer.write(example.SerializeToString())
writer.close()
上面的代码中发生的事情是将数组变成字符串(一维对象),然后存储为字节特征。
然后,使用tf.data.TFRecordDataset
和tf.data.Iterator
类读回数据:
filename = 'toy.tfrecord'
# Create a placeholder that will contain the name of the TFRecord file to use
data_path = tf.placeholder(dtype=tf.string, name="tfrecord_file")
# Create the dataset from the TFRecord file
dataset = tf.data.TFRecordDataset(data_path)
# Use the map function to read every sample from the TFRecord file (_read_from_tfrecord is shown below)
dataset = dataset.map(_read_from_tfrecord)
# Create an iterator object that enables you to access all the samples in the dataset
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
label_tf, sample_tf = iterator.get_next()
# Similarly to tf.Variables, the iterators have to be initialised
iterator_init = iterator.make_initializer(dataset, name="dataset_init")
with tf.Session() as sess:
# Initialise the iterator passing the name of the TFRecord file to the placeholder
sess.run(iterator_init, feed_dict={data_path: filename})
# Obtain the images and labels back
read_label, read_sample = sess.run([label_tf, sample_tf])
函数_read_from_tfrecord()
是:
def _read_from_tfrecord(example_proto):
feature = {
'label_raw': tf.FixedLenFeature([], tf.string),
'sample_raw': tf.FixedLenFeature([], tf.string)
}
features = tf.parse_example([example_proto], features=feature)
# Since the arrays were stored as strings, they are now 1d
label_1d = tf.decode_raw(features['label_raw'], tf.int64)
sample_1d = tf.decode_raw(features['sample_raw'], tf.int64)
# In order to make the arrays in their original shape, they have to be reshaped.
label_restored = tf.reshape(label_1d, tf.stack([2, 3, -1]))
sample_restored = tf.reshape(sample_1d, tf.stack([2, 3, -1]))
return label_restored, sample_restored
除了对形状[2, 3, -1]
进行硬编码之外,您还可以将其也存储到TFRecord文件中,但是为了简单起见,我没有这样做。
我用一个可行的例子做了一点gist。
希望这会有所帮助!