如何使用存储在TFRecords文件中的图像为Estimator构建input_fn

时间:2017-03-07 22:43:24

标签: python tensorflow

是否有一个示例如何为input_fn构建图像分类模型所需的tf.contrib.learn.Estimator?我的图像存储在多个TFRecords文件中。

使用tf.contrib.learn.read_batch_record_features,我可以生成批量的编码图像字符串。但是,我没有看到将这些字符串转换为图像的简单方法。

1 个答案:

答案 0 :(得分:2)

参考here,您可以对mnistfashion-mnist中存储的train.tfrecordstest.tfrecords数据集使用以下内容。

转换为tfrecords由代码here完成,您需要使用解析器来取回原始图像和标签。

def parser(serialized_example):
  """Parses a single tf.Example into image and label tensors."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([28 * 28])

  # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
  image = tf.cast(image, tf.float32) / 255 - 0.5
  label = tf.cast(features['label'], tf.int32)
  return image, label

在使用解析器之后,其余部分很简单,您只需要调用TFRecordDataset(train_filenames)然后将解析器函数映射到每个元素,这样您就可以获得图像和标签作为输出。

# Keep list of filenames, so you can input directory of tfrecords easily
training_filenames = ["data/train.tfrecords"]
test_filenames = ["data/test.tfrecords"]

# Define the input function for training
def train_input_fn():
  # Import MNIST data
  dataset = tf.contrib.data.TFRecordDataset(train_filenames)

  # Map the parser over dataset, and batch results by up to batch_size
  dataset = dataset.map(parser, num_threads=1, output_buffer_size=batch_size)
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat()
  iterator = dataset.make_one_shot_iterator()

  features, labels = iterator.get_next()

  return features, labels