交错tf.data.Datasets

时间:2017-11-17 04:19:11

标签: tensorflow

我试图使用tf.data.Dataset交错两个数据集,但遇到了问题。给出这个简单的例子:

ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()

生成类似...的输出的0, 1, 2, 3...9是什么?

看起来datat.interleave()似乎是相关的,但我还没有能够以一种不会产生错误的方式来表达语句。

3 个答案:

答案 0 :(得分:18)

MattScarpino在his comment的正确轨道上。您可以使用Dataset.zip()Dataset.flat_map()来展平多元素数据集:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))

iter = dataset.make_one_shot_iterator()
val = iter.get_next()

话虽如此,你对使用Dataset.interleave()的直觉是非常明智的。我们正在研究您可以更轻松地完成此任务的方法。

PS。作为替代方案,如果您更改了Dataset.interleave()ds0的定义方式,可以使用ds1来解决问题:

dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)

答案 1 :(得分:2)

如果您不需要保留要插入的项目的严格顺序,则

tf.data.experimental.sample_from_datasets方法也可能很有用。

在我的情况下,我必须将真实数据与一些综合数据进行交错处理,因此顺序对我来说不是问题。然后可以很容易地做到如下

dataset = tf.data.experimental.sample_from_datasets([ds0, ds1])

请注意,结果将是不确定的,并且某些项目可以从同一数据集中获取两次,但通常与常规交织非常相似。

这种方法的优点:

  • 您可以在一个方法调用中使用多个数据集
  • 您可以使用weights参数为每个数据集指定样本的比例(例如,我只想生成一小部分数据,所以我使用了weights=[0.9, 0.1]

答案 2 :(得分:2)

如果您不介意交错顺序,Pavel 的答案会很有效。如果你确实关心...

选项 1

适用于任意数量的输入数据集的 mrry 答案的变体:

ds0 = tf.data.Dataset.range(0, 10, 3)
ds1 = tf.data.Dataset.range(1, 10, 3)
ds2 = tf.data.Dataset.range(2, 10, 3)
datasets = (ds0, ds1, ds2)

# Note: `datasets` should be a *tuple*, not a list.
datasets_zipped = tf.data.Dataset.zip(datasets)
# Each element of the dataset will now be a tuple, e.g. (0, 1, 2).
datasets_zipped_tensor = datasets_zipped.map(lambda *args: tf.stack(args))
# Each element will now be a Tensor, e.g. Tensor([0, 1, 2]).
datasets_interleaved = datasets_zipped_tensor.unbatch()

但是,请注意,由于 zip 的工作方式,生成的数据集受限于最短输入数据集的长度。例如,将上面的代码与

一起使用
datasets = [
    tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5]),
    tf.data.Dataset.from_tensor_slices([10, 20]),
)

将产生仅包含 [1, 10, 2, 20] 的数据集。

选项 2

Dataset.interleave 没有这个问题。在某些情况下,您可以将 interleave 用于:

# Note: `datasets` should be a *list*, not a tuple
tf.data.Dataset.from_tensor_slices(datasets).interleave(lambda x: x)

但这似乎不适用于所有类型的数据集; 对您的数据集调用 from_tensor_slices 可能不起作用

选项 3

如果选项 2 不起作用,您或许可以在数据集管道的早期阶段使用 interleave。例如,您可能能够从对预先存在的数据集调用 interleave 更改为对创建单个数据集的文件名调用 interleave

filenames = ['foo', 'bar']
filesnames_dataset = tf.data.Dataset.from_tensor_slices(filenames)

def read_dataset(filename): ...

interleaved_dataset = filenames_dataset.interleave(read_dataset)

但是这仅在您的 read_dataset 函数接受 Tensor 参数时有效

选项 4

如果其他选项都不适合您,我认为唯一的解决方案是自己实现交错,例如:

element_spec = datasets[0].element_spec
assert all(dataset.element_spec == element_spec for dataset in datasets)

def interleave_generator():
  iters_not_exhausted = [iter(dataset) for dataset in datasets]
  while iters_not_exhausted:
    for dataset_iter in iters_not_exhausted:
      try:
        x = next(dataset_iter)
      except StopIteration:
        iters_not_exhausted.remove(dataset_iter)
      else:
        yield x

datasets_interleaved = tf.data.Dataset.from_generator(
    interleave_generator,
    output_signature=element_spec,
)