如何洗牌串联的Tensorflow数据集

时间:2018-08-09 10:43:37

标签: tensorflow dataset

我有多个具有相同结构的张量流数据集。 我想将它们组合到单个数据集中。使用 tf.dataset.concatenate

但是我发现,在对这个组合数据集进行改组时,该数据集并未在整个数据集的范围内进行改组。但是在每个单独的数据集中都进行了改组。

有什么方法可以解决这个问题?

4 个答案:

答案 0 :(得分:3)

将两个Dataset连接在一起时,将得到第一个的元素,然后是第二个的元素。如果对结果进行混洗,并且混洗缓冲区小于Dataset的大小,则混合效果不好。

您需要的是从数据集中插入样本。如果使用TF> = 1.9,最好的方法是使用专用的tf.contrib.data.choose_from_datasets函数。直接来自文档的示例:

datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
            tf.data.Dataset.from_tensors("bar").repeat(),
            tf.data.Dataset.from_tensors("baz").repeat()]

# Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
choice_dataset = tf.data.Dataset.range(3).repeat(3)

result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

如果保留批次中的样品顺序和/或比例很重要,则最好改组输入数据集。

如果您使用的是TF的早期版本,则可以依靠zipflat_mapconcatenate的组合,如下所示:

a = tf.data.Dataset.range(3).repeat()
b = tf.data.Dataset.range(100, 105).repeat()

value = (tf.data.Dataset
  .zip((a, b))
  .flat_map(lambda x, y: tf.data.Dataset.concatenate(
    tf.data.Dataset.from_tensors([x]),
    tf.data.Dataset.from_tensors([y])))
  .make_one_shot_iterator()
  .get_next())

sess = tf.InteractiveSession()

for _ in range(10):
  print(value.eval())

答案 1 :(得分:0)

不能100%确定,但是您可能需要研究在数据集对象上调用不同操作的顺序。 shuffle()的行为可能会因顺序而异。另请参阅可能与之相关的this问题。

答案 2 :(得分:0)

洗牌缓冲区的大小是多少?

例如,如果您有3个数据集,每个数据集包含1000个项目,则需要应用shuffle(3000)来随机化所有项目的顺序。

这里是一个例子:

这应该将3000个项目全部洗掉:

dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
dataset = dataset.shuffle(3000)

但是,这将 洗牌整个数据集:

dataset1 = dataset1.shuffle(1000)
dataset2 = dataset2.shuffle(1000)
dataset3 = dataset3.shuffle(1000)
dataset = dataset1.concatenate(dataset2).concatenate(dataset3)

答案 3 :(得分:0)

从tensorflow 1.9开始,您还可以使用sample_from_datasets方法。

例如,以下代码

datasets = [tf.data.Dataset.from_tensors("foo").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
        tf.data.Dataset.from_tensors("bar").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
        tf.data.Dataset.from_tensors("baz").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat()]

dataset = tf.data.experimental.sample_from_datasets(datasets) # from 1.12
# dataset = tf.contrib.data.sample_from_datasets(datasets) # between 1.9 and 1.12

iterator = dataset.make_one_shot_iterator();next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(next_element))

将打印

(0, b'bar')
(0, b'foo')
(1, b'bar')
(0, b'baz')
(2, b'bar')
(1, b'foo')
(1, b'baz')
(2, b'foo')
(2, b'baz')
(0, b'foo')