如何在Tensorflow Estimator的input_fn中执行数据扩充

时间:2018-07-30 12:30:18

标签: tensorflow machine-learning training-data data-augmentation

使用Tensorflow的Estimator API,我应该在管道的哪一点执行数据增强?

根据该官方Tensorflow guideinput_fn中有一个执行数据扩充的地方:

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLengthFeature((), tf.string, ""),
    "label": tf.FixedLengthFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])

  # augments image using slice, reshape, resize_bilinear
  #         |
  #         |
  #         |
  #         v
  image = _augment_helper(image)

  return image, parsed["label"]

def input_fn():
  files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
  dataset = files.interleave(tf.data.TFRecordDataset)
  dataset = dataset.map(map_func=parse_fn)
  # ...
  return dataset

我的问题

如果我在input_fn内执行数据增强,parse_fn是否返回单个示例或包含原始输入图像和所有增强变体的批次?如果只返回一个[增强的]示例,如何确保数据集中的所有图像以其非增强形式以及所有变体形式使用?

2 个答案:

答案 0 :(得分:0)

如果您在数据集上使用迭代器,则_augment_helper函数将在输入的每个数据块中的数据集的每次迭代中调用(就像您在dataset.map中调用parse_fn一样)

将代码更改为

  ds_iter = dataset.make_one_shot_iterator()
  ds_iter = ds_iter.get_next()
  return ds_iter

我已经使用简单的增强功能对此进行了测试

  def _augment_helper(image):
       print(image.shape)
       image = tf.image.random_brightness(image,255.0, 1)
       image = tf.clip_by_value(image, 0.0, 255.0)
       return image

将255.0更改为数据集中的最大值,我将255.0用作示例数据集的8位像素值

答案 1 :(得分:0)

每次调用parse_fn时,它将返回单个示例,然后,如果使用.batch()操作,它将返回一批已解析的图像