使用TensorFlow Dataset API改组/扩充数据

时间:2018-11-04 07:26:31

标签: python tensorflow dataset shuffle data-augmentation

在过去的几天里,我一直在尝试使自己熟悉TensorFlow Dataset API。 我的目标是建立一个将ImageNet馈入模型的工作流程。

我已经实现了一些有效的代码,并且这样做了,它会读取TFRecord文件的N个分片。 但是,我很难做到这一点。有几个问题...

  1. 我想对分片序列进行混洗,以便对整个数据集进行混洗。我一直在尝试使用https://github.com/tensorflow/tensorflow/issues/14857中介绍的方法。 但是,当我尝试使用 tf.data.Dataset.list_files(filenames)时,出现以下错误
  

OP_REQUIRES在example_parsing_ops.cc:144处失败:无效的参数:无法解析示例输入,值:'val / validation-00033-of-00128'

  1. 为了减少延迟,我尝试使用 num_parallel_calls 。但是,这似乎没有使它更快(我目前正在使用i7 8700K)。另外,当我尝试增加预取或随机播放的参数时,似乎会使它减慢很多。难道我做错了什么?还是我使用Dataset API的顺序错误?

  2. 我希望构建它以便可以用于训练模型,并且我想在 decode_for_train 函数中实现数据增强,该功能将是 decode_for_eval的副本功能。但是,我不知道如何。我考虑过的一种方法是添加一条仅使[224,224,3]到[?,224,224,3]的行。但是,我担心这会使运行迭代器的输出从[224,224,3]变为[?,224,224,3]。有这样做的聪明方法吗?

谢谢您的帮助!


当前,我的数据集已转换为TFRecords:

val / validation-00000-of-00128到val / validation-00128-of-00128

我当前的代码是这样的。

def decode_for_eval(example):
    image = image_ops.decode_jpeg(example, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.central_crop(image, central_fraction=0.875)
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(image, [224, 224], align_corners=False)
    image = tf.squeeze(image, [0])
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image

def decode(serialized_example):
    features = tf.parse_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], tf.string, default_value='jpeg'),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=-1),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
        }
    )
    width = tf.cast(features['image/width'], tf.int32)
    height = tf.cast(features['image/height'], tf.int32)

    ### EVAL ###
    image = tf.map_fn(lambda x: decode_for_eval(x), features['image/encoded'], dtype=tf.float32)

    ### TRAIN ###
    # I want to build this!

    label = tf.cast(features['image/class/label'], tf.int32)

    return image, label

def get_iterator(filenames, batch_size):
    with tf.name_scope('input'):
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.repeat(None)
        dataset = dataset.shuffle(5000)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(decode, num_parallel_calls=4)
        dataset = dataset.prefetch(batch_size)
        iterator = dataset.make_initializable_iterator()
    return iterator

def main():
    sess = tf.Session()

    # get iterator for single file
    dataset_dir = 'val/validation-*'
    filenames = glob.glob(dataset_dir)
    iterator = get_iterator(filenames, 128)
    sess.run(iterator.initializer)
    image, label = iterator.get_next()

    images, labels = sess.run([image, label])

0 个答案:

没有答案