如何在tensorflow tfrecords中增加数据?

时间:2018-01-19 17:51:05

标签: tensorflow deep-learning tensorflow-datasets tensorflow-estimator tensorflow-slim

我正在使用tfrecords存储我的数据,并使用Dataset API将其作为张量读取,然后我使用Estimator API执行培训。现在,我想对数据集中的每个项目进行在线数据扩充,但经过一段时间的尝试后,我找不到出路的方法。我想要随机翻转,随机旋转和其他操纵器。

我遵循this教程中的说明,使用自定义估算器,这是我的CNN,我不确定数据增强步骤的发生位置。

1 个答案:

答案 0 :(得分:2)

使用TFRecords不会阻止您进行数据扩充。

按照您在评论中链接的tutorial,以下是大致发生的事情:

  • 您可以从TFRecords文件创建数据集,并解析文件以获得imagelabel
dataset = tf.data.TFRecordDataset(filenames=filenames)
dataset = dataset.map(parse)
  • 现在,您可以应用新的预处理功能,以便在培训期间进行一些数据扩充
# Only do it when we are training
if train:
    dataset = dataset.map(train_preprocess)
  • train_preprocess函数可以是这样的:
def train_preprocess(image, label):
    flip_image = tf.image.random_flip_left_right(image)
    # Other transformations...
    return flip_image, label