使用Tensorflow的Estimator API,我应该在管道的哪一点执行数据增强?
根据该官方Tensorflow guide,input_fn
中有一个执行数据扩充的地方:
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
# augments image using slice, reshape, resize_bilinear
# |
# |
# |
# v
image = _augment_helper(image)
return image, parsed["label"]
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.map(map_func=parse_fn)
# ...
return dataset
如果我在input_fn
内执行数据增强,parse_fn
是否返回单个示例或包含原始输入图像和所有增强变体的批次?如果只返回一个[增强的]示例,如何确保数据集中的所有图像以其非增强形式以及所有变体形式使用?
答案 0 :(得分:0)
如果您在数据集上使用迭代器,则_augment_helper函数将在输入的每个数据块中的数据集的每次迭代中调用(就像您在dataset.map中调用parse_fn一样)
将代码更改为
ds_iter = dataset.make_one_shot_iterator()
ds_iter = ds_iter.get_next()
return ds_iter
我已经使用简单的增强功能对此进行了测试
def _augment_helper(image):
print(image.shape)
image = tf.image.random_brightness(image,255.0, 1)
image = tf.clip_by_value(image, 0.0, 255.0)
return image
将255.0更改为数据集中的最大值,我将255.0用作示例数据集的8位像素值
答案 1 :(得分:0)
每次调用parse_fn时,它将返回单个示例,然后,如果使用.batch()操作,它将返回一批已解析的图像