How can I better optimize my TensorFlow data pipeline for video?

Asked: 2019-01-11 17:27:50

Tags: tensorflow tensorflow-datasets

I have video sequences from which I need to (randomly) extract contiguous frames along with the data that coincides with them. I am mostly following @mrry's SO answer (TensorFlow - Read video frames from TFRecords file) to get easy access to the frames.
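
For context, each TFRecord example stores the frames as per-frame JPEG-encoded strings. The parsing inside my _get_decode_inputs helper looks roughly like the sketch below (the exact feature names are illustrative, not my real schema):

features = {
    'num_frames': tf.FixedLenFeature([], tf.int64),  # hypothetical name
    'frames': tf.VarLenFeature(tf.string),  # one JPEG-encoded string per frame
}
parsed_features = tf.parse_single_example(serialized_example, features)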

My pipeline currently looks like the following, where I call get_dataset with the appropriate arguments. An example decode_func to pass to get_dataset is also provided below. Data loading really is the bottleneck here: according to a FULL_TRACE in RunOptions, this step can take 5 to 10 seconds. Increasing the number of parallel calls does not seem to help much either.
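
(For reference, this is roughly how I collect that trace; sess and next_element stand in for my session and the iterator output:)

from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(next_element, options=run_options, run_metadata=run_metadata)

# Write a Chrome trace to inspect where the time is spent.
tl = timeline.Timeline(run_metadata.step_stats)
with open('timeline.json', 'w') as f:
    f.write(tl.generate_chrome_trace_format())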

I tried to optimize this by splitting _decode into decode_cache and decode_post_cache, mapping the former and then applying dataset.cache() before mapping the latter. The assumption was that the expensive part is decoding the frame images from strings to floats. However, this means I would have to cache every frame of the entire dataset in memory (so that decode_post_cache can take a random offset efficiently), which is problematic because of the dataset's size.
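
(Sketch of that attempt; decode_cache and decode_post_cache are the two halves of _decode described above:)

# decode_cache parses the example and decodes every frame to float32 (the expensive part).
dataset = dataset.map(decode_cache, num_parallel_calls=decode_parallel_calls)
# Cache the fully decoded frames so the string-to-float decoding happens only once...
dataset = dataset.cache()
# ...then pick the random num_frames window (and compute labels) on each pass.
dataset = dataset.map(decode_post_cache)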

Is there a better way to optimize this? It is currently the main speed bottleneck.

import glob

import tensorflow as tf


def get_dataset(location, decode_func, num_epochs, batch_size,
                prefetch_batch_size, decode_parallel_calls):
    """Gets the tf.data.Dataset object.

    Args:
      location: The location of the TFRecords to ingest, possibly a regex.
      decode_func: The function to use to decode the examples.
      num_epochs: The number of epochs. 
      batch_size: The batch size.
      prefetch_batch_size: How much to prefetch in buffer.
      decode_parallel_calls: How many parallel cpu calls to make to decode.

    Returns:
      a tf.data.Dataset object.
    """
    num_files = len(glob.glob(location))
    files = tf.data.Dataset.list_files(location).shuffle(num_files)
    dataset = files.interleave(
        tf.data.TFRecordDataset, cycle_length=1)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(decode_func, num_parallel_calls=decode_parallel_calls)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=prefetch_batch_size)
    return dataset
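
For reference, I construct decode_func with functools.partial and call get_dataset along these lines (the argument values here are illustrative):

import functools

decode_func = functools.partial(
    _decode, num_frames=100, preprocess_func=None, is_heldout=False, fps=25)
dataset = get_dataset(
    'data/train-*.tfrecords', decode_func, num_epochs=10, batch_size=8,
    prefetch_batch_size=2, decode_parallel_calls=4)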


def _decode(serialized_example,
            num_frames=100,
            preprocess_func=None,
            is_heldout=False,
            fps=25):
    """
    Decode the serialized_example into tensors suitable for tf.data.

    Args:
      serialized_example: The example decoded from TFRecords.
      num_frames: The number of frames to consider.
      preprocess_func: How to process the image after decoding it, e.g. image augmentation.
      is_heldout: Whether this is a held out set for train validation purposes.
      fps: The frames per second.
    """
    inputs = _get_decode_inputs(serialized_example, fps)
    parsed_features, datum_num_frames, half_num_frames = inputs[:3]

    ###
    # Get data related to labels.
    # ...
    ###

    # All offsets from which a full num_frames window can be taken.
    all_frames = tf.range(0, datum_num_frames - num_frames)
    if is_heldout:
        random_offset = half_num_frames
    else:
        heldout_range = tf.range(half_num_frames - num_frames,
                                 half_num_frames + num_frames)
        heldout_range = tf.reshape(heldout_range, [1, -1])
        all_frames = tf.reshape(all_frames, [1, -1])
        all_frames = tf.sets.set_difference(all_frames, heldout_range)
        # Flatten the sparse set difference back to a 1-D tensor of valid offsets,
        # so that tf.gather below indexes individual offsets rather than rows.
        all_frames = tf.reshape(tf.sparse.to_dense(all_frames), [-1])

        index = tf.random_uniform(
            (),
            minval=0,
            maxval=tf.size(all_frames),
            dtype=tf.int32,
            name='RandomIndex')
        random_offset = tf.cast(tf.gather(all_frames, index), tf.int64)

    ###
    # Do some light computation related to data requirements related to labels, not the images themselves.
    # ...
    # labels = ...
    # ...
    ###

    offsets = tf.range(random_offset, random_offset + num_frames)
    # Decode only the JPEG strings inside the chosen window, frame by frame.
    frames = tf.map_fn(
        lambda i: _decode_img(parsed_features['frames'].values[i], tf.image.decode_jpeg, preprocess_func),
        offsets,
        dtype=tf.float32)

    return frames, labels


def _decode_img(img, decode_func, preprocess_func):
    # Reverse the channel order returned by decode_func (RGB -> BGR).
    img = tf.reverse(decode_func(img), axis=[-1])
    img = tf.image.convert_image_dtype(img, tf.float32)
    if preprocess_func:
        img = preprocess_func(img)
    return img
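
I then consume the returned dataset in the usual TF 1.x way, e.g.:

iterator = dataset.make_one_shot_iterator()
frames, labels = iterator.get_next()

with tf.Session() as sess:
    frames_np, labels_np = sess.run([frames, labels])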

0 Answers:

No answers yet