tensorflow-tf.data.Dataset在批处理之前随机跳过样本以获取不同的批次

时间:2018-11-23 19:14:55

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我的模型使用每个输入批次中按时间顺序排列的序列。因此,我将在重新整理输入数据之前创建批次。这就带来了一个问题,批处理始终在整个数据集中包含相同的数据样本(从相同的索引开始-偏移batch_size),我通过缓存初始数据集并从跳过的数据集中采样来解决了这个问题,但是内存很快(尽管我的数据集只有150MB):

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.window(size=window_size, shift=window_shift, stride=window_stride, drop_remainder=True).flat_map(lambda x: x.batch(window_size))
dataset = dataset.map(process_fn, num_parallel_calls=8)
dataset = dataset.cache()
datasets = []
for i in range(0, batch_size):
    d = dataset.skip(i)
    d = d.batch(batch_size, drop_remainder=True)
    datasets.append(d)
dataset = tf.data.experimental.sample_from_datasets(datasets)
dataset = dataset.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
dataset = dataset.repeat()

还有另一种方法可以实现这种行为吗?我想涵盖批次中第一个序列开始的所有可能的索引。

1 个答案:

答案 0 :(得分:0)

您正在消耗内存,因为您要对整批产品进行改组-跳过也可能不是很有效。由于您的数据似乎是内存中的全部数据,因此您可以直接在python中对数据进行采样,而不必太担心性能:

def make_batch(start_idx):
  batch = np.empty((batch_size, window_size), dtype=data.dtype)
  for batch_idx, data_idx in enumerate(
      range(start_idx, start_idx + window_shift * batch_size, window_shift)):
    batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
  return batch

dataset = (tf.data.Dataset
  .range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
  .shuffle(buffer_size=30000, reshuffle_each_iteration=False)
  .map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
  .repeat()
  .prefetch(1)) # you might want to consider prefetching for performance

改组现在发生在索引上,而不是整个批次上,因此内存占用量要低得多。