I have some video sequences from which I need to (randomly) extract consecutive frames and augment the data at the same time. I am mostly following @mrry's SO answer (TensorFlow - Read video frames from TFRecords file) to get easy access to the frames.
My pipeline currently looks like the following, where I call get_dataset with the appropriate arguments; an example decode_func to pass into get_dataset is also provided below. Data loading is definitely the bottleneck here: according to a FULL_TRACE in RunOptions, this step can take 5 to 10 seconds. Increasing the number of parallel calls does not seem to help much either.
I tried to optimize this by splitting _decode into decode_cache and decode_post_cache, and applying dataset.cache() after mapping the former. The assumption was that the expensive part is decoding the frame images from strings to floats. However, this means I have to hold all frames of the entire dataset in memory (in order to efficiently take a random offset in the decode_post_cache map), which is problematic because of the dataset's size.
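For illustration, here is a plain-NumPy sketch of that split (decode_cache and decode_post_cache below are hypothetical stand-ins, not the real TF implementation): the expensive per-frame decode runs once per video, and its output is what dataset.cache() would have to hold, while the cheap random-window step runs on every epoch.

```python
import numpy as np

def decode_cache(raw_frames):
    # Expensive step, done once per video: decode every frame to float32.
    # This is the part whose output dataset.cache() would keep in memory.
    return np.asarray(raw_frames, dtype=np.float32) / 255.0

def decode_post_cache(decoded, num_frames, rng):
    # Cheap step, done per epoch: slice a random contiguous window of frames.
    offset = rng.integers(0, len(decoded) - num_frames + 1)
    return decoded[offset:offset + num_frames]

rng = np.random.default_rng(0)
video = decode_cache(np.arange(100))      # 100 toy "frames"
clip = decode_post_cache(video, 16, rng)  # one random 16-frame window
print(clip.shape)  # (16,)
```

The memory problem is visible here too: decode_cache must materialize the whole decoded video, even though each epoch only ever reads a num_frames-sized slice of it.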
Is there a better way to optimize this? It is currently the speed bottleneck.
import glob

import tensorflow as tf


def get_dataset(location, decode_func, num_epochs, batch_size,
                prefetch_batch_size, decode_parallel_calls):
    """Gets the tf.data.Dataset object.

    Args:
        location: The location of the TFRecords to ingest, possibly a regex.
        decode_func: The function to use to decode the examples.
        num_epochs: The number of epochs.
        batch_size: The batch size.
        prefetch_batch_size: How much to prefetch in the buffer.
        decode_parallel_calls: How many parallel CPU calls to make to decode.

    Returns:
        A tf.data.Dataset object.
    """
    num_files = len(glob.glob(location))
    files = tf.data.Dataset.list_files(location).shuffle(num_files)
    dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=1)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(decode_func, num_parallel_calls=decode_parallel_calls)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=prefetch_batch_size)
    return dataset
def _decode(serialized_example,
            num_frames=100,
            preprocess_func=None,
            is_heldout=False,
            fps=25):
    """Decodes the serialized_example into tensors suitable for tf.data.

    Args:
        serialized_example: The example decoded from TFRecords.
        num_frames: The number of frames to consider.
        preprocess_func: How to process the image after decoding it,
            e.g. image augmentation.
        is_heldout: Whether this is a held-out set for train/validation purposes.
        fps: The frames per second.
    """
    inputs = _get_decode_inputs(serialized_example, fps)
    parsed_features, datum_num_frames, half_num_frames = inputs[:3]
    ###
    # Get data related to labels.
    # ...
    ###
    all_frames = tf.range(0, datum_num_frames - num_frames)
    if is_heldout:
        random_offset = half_num_frames
    else:
        # Exclude the band of offsets reserved for the held-out clip.
        heldout_range = tf.range(half_num_frames - num_frames,
                                 half_num_frames + num_frames)
        heldout_range = tf.reshape(heldout_range, [1, -1])
        all_frames = tf.reshape(all_frames, [1, -1])
        all_frames = tf.sets.set_difference(all_frames, heldout_range)
        all_frames = tf.sparse.to_dense(all_frames)
        index = tf.random_uniform(
            (),
            minval=0,
            maxval=tf.size(all_frames),
            dtype=tf.int32,
            name='RandomIndex')
        random_offset = tf.cast(tf.gather(all_frames, index), tf.int64)
    ###
    # Do some light computation related to the labels, not the images themselves.
    # ...
    # labels = ...
    # ...
    ###
    offsets = tf.range(random_offset, random_offset + num_frames)
    frames = tf.map_fn(
        lambda i: _decode_img(parsed_features['frames'].values[i],
                              tf.image.decode_jpeg, preprocess_func),
        offsets,
        dtype=tf.float32)
    return frames, labels
def _decode_img(img, decode_func, preprocess_func):
    # Flip the channel axis (BGR <-> RGB), then convert to float32 in [0, 1].
    img = tf.reverse(decode_func(img), axis=[-1])
    img = tf.image.convert_image_dtype(img, tf.float32)
    if preprocess_func:
        img = preprocess_func(img)
    return img
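As a sanity check on the offset logic in _decode, here is the same held-out exclusion written in plain NumPy (the sizes are made up for illustration): np.setdiff1d plays the role of tf.sets.set_difference, and a uniform draw over the remaining offsets mirrors tf.random_uniform plus tf.gather.

```python
import numpy as np

# Hypothetical sizes for illustration only.
datum_num_frames, num_frames = 100, 10
half_num_frames = datum_num_frames // 2

# Valid window starts, minus the band reserved for the held-out clip.
all_offsets = np.arange(0, datum_num_frames - num_frames)
heldout_range = np.arange(half_num_frames - num_frames,
                          half_num_frames + num_frames)
valid = np.setdiff1d(all_offsets, heldout_range)  # ~ tf.sets.set_difference

rng = np.random.default_rng(0)
random_offset = valid[rng.integers(0, valid.size)]  # ~ random_uniform + gather
offsets = np.arange(random_offset, random_offset + num_frames)
print(random_offset in heldout_range)  # False: never starts in the excluded band
```

The band spans num_frames on each side of half_num_frames, so a training window starting just before it still cannot reach into the held-out clip itself.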