是否有一个示例如何为input_fn
构建图像分类模型所需的tf.contrib.learn.Estimator
?我的图像存储在多个TFRecords文件中。
使用tf.contrib.learn.read_batch_record_features
,我可以生成批量的编码图像字符串。但是,我没有看到将这些字符串转换为图像的简单方法。
答案 0 :(得分:2)
参考here,您可以对mnist
和fashion-mnist
中存储的train.tfrecords
和test.tfrecords
数据集使用以下内容。
转换为tfrecords
由代码here完成,您需要使用解析器来取回原始图像和标签。
def parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([28 * 28])
# Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
image = tf.cast(image, tf.float32) / 255 - 0.5
label = tf.cast(features['label'], tf.int32)
return image, label
在使用解析器之后,其余部分很简单,您只需要调用TFRecordDataset(train_filenames)
然后将解析器函数映射到每个元素,这样您就可以获得图像和标签作为输出。
# Keep list of filenames, so you can input directory of tfrecords easily
training_filenames = ["data/train.tfrecords"]
test_filenames = ["data/test.tfrecords"]
# Define the input function for training
def train_input_fn():
# Import MNIST data
dataset = tf.contrib.data.TFRecordDataset(train_filenames)
# Map the parser over dataset, and batch results by up to batch_size
dataset = dataset.map(parser, num_threads=1, output_buffer_size=batch_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels