我试图使用tf.data.Dataset交错两个数据集,但遇到了问题。给出这个简单的例子:
ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()
生成类似...
的输出的0, 1, 2, 3...9
是什么?
看起来datat.interleave()似乎是相关的,但我还没有能够以一种不会产生错误的方式来表达语句。
答案 0 :(得分:18)
MattScarpino在his comment的正确轨道上。您可以使用Dataset.zip()
和Dataset.flat_map()
来展平多元素数据集:
ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)
# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
tf.data.Dataset.from_tensors(x1)))
iter = dataset.make_one_shot_iterator()
val = iter.get_next()
话虽如此,你对使用Dataset.interleave()
的直觉是非常明智的。我们正在研究您可以更轻松地完成此任务的方法。
PS。作为替代方案,如果您更改了Dataset.interleave()
和ds0
的定义方式,可以使用ds1
来解决问题:
dataset = tf.data.Dataset.range(2).interleave(
lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)
答案 1 :(得分:2)
tf.data.experimental.sample_from_datasets方法也可能很有用。
在我的情况下,我必须将真实数据与一些综合数据进行交错处理,因此顺序对我来说不是问题。然后可以很容易地做到如下
dataset = tf.data.experimental.sample_from_datasets([ds0, ds1])
请注意,结果将是不确定的,并且某些项目可以从同一数据集中获取两次,但通常与常规交织非常相似。
这种方法的优点:
weights
参数为每个数据集指定样本的比例(例如,我只想生成一小部分数据,所以我使用了weights=[0.9, 0.1]
)答案 2 :(得分:2)
如果您不介意交错顺序,Pavel 的答案会很有效。如果你确实关心...
适用于任意数量的输入数据集的 mrry 答案的变体:
ds0 = tf.data.Dataset.range(0, 10, 3)
ds1 = tf.data.Dataset.range(1, 10, 3)
ds2 = tf.data.Dataset.range(2, 10, 3)
datasets = (ds0, ds1, ds2)
# Note: `datasets` should be a *tuple*, not a list.
datasets_zipped = tf.data.Dataset.zip(datasets)
# Each element of the dataset will now be a tuple, e.g. (0, 1, 2).
datasets_zipped_tensor = datasets_zipped.map(lambda *args: tf.stack(args))
# Each element will now be a Tensor, e.g. Tensor([0, 1, 2]).
datasets_interleaved = datasets_zipped_tensor.unbatch()
但是,请注意,由于 zip
的工作方式,生成的数据集受限于最短输入数据集的长度。例如,将上面的代码与
datasets = [
tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5]),
tf.data.Dataset.from_tensor_slices([10, 20]),
)
将产生仅包含 [1, 10, 2, 20]
的数据集。
Dataset.interleave
没有这个问题。在某些情况下,您可以将 interleave
用于:
# Note: `datasets` should be a *list*, not a tuple
tf.data.Dataset.from_tensor_slices(datasets).interleave(lambda x: x)
但这似乎不适用于所有类型的数据集; 对您的数据集调用 from_tensor_slices
可能不起作用。
如果选项 2 不起作用,您或许可以在数据集管道的早期阶段使用 interleave
。例如,您可能能够从对预先存在的数据集调用 interleave
更改为对创建单个数据集的文件名调用 interleave
:
filenames = ['foo', 'bar']
filesnames_dataset = tf.data.Dataset.from_tensor_slices(filenames)
def read_dataset(filename): ...
interleaved_dataset = filenames_dataset.interleave(read_dataset)
但是这仅在您的 read_dataset
函数接受 Tensor
参数时有效。
如果其他选项都不适合您,我认为唯一的解决方案是自己实现交错,例如:
element_spec = datasets[0].element_spec
assert all(dataset.element_spec == element_spec for dataset in datasets)
def interleave_generator():
iters_not_exhausted = [iter(dataset) for dataset in datasets]
while iters_not_exhausted:
for dataset_iter in iters_not_exhausted:
try:
x = next(dataset_iter)
except StopIteration:
iters_not_exhausted.remove(dataset_iter)
else:
yield x
datasets_interleaved = tf.data.Dataset.from_generator(
interleave_generator,
output_signature=element_spec,
)