我的模型使用每个输入批次中按时间顺序排列的序列。因此,我将在重新整理输入数据之前创建批次。这就带来了一个问题,批处理始终在整个数据集中包含相同的数据样本(从相同的索引开始-偏移batch_size
),我通过缓存初始数据集并从跳过的数据集中采样来解决了这个问题,但是内存很快(尽管我的数据集只有150MB):
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.window(size=window_size, shift=window_shift, stride=window_stride, drop_remainder=True).flat_map(lambda x: x.batch(window_size))
dataset = dataset.map(process_fn, num_parallel_calls=8)
dataset = dataset.cache()
datasets = []
for i in range(0, batch_size):
d = dataset.skip(i)
d = d.batch(batch_size, drop_remainder=True)
datasets.append(d)
dataset = tf.data.experimental.sample_from_datasets(datasets)
dataset = dataset.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
dataset = dataset.repeat()
还有另一种方法可以实现这种行为吗?我想涵盖批次中第一个序列开始的所有可能的索引。
答案 0 :(得分:0)
您正在消耗内存,因为您要对整批产品进行改组-跳过也可能不是很有效。由于您的数据似乎是内存中的全部数据,因此您可以直接在python中对数据进行采样,而不必太担心性能:
def make_batch(start_idx):
batch = np.empty((batch_size, window_size), dtype=data.dtype)
for batch_idx, data_idx in enumerate(
range(start_idx, start_idx + window_shift * batch_size, window_shift)):
batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
return batch
dataset = (tf.data.Dataset
.range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
.map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
.repeat()
.prefetch(1)) # you might want to consider prefetching for performance
改组现在发生在索引上,而不是整个批次上,因此内存占用量要低得多。