我有多个具有相同结构的张量流数据集。 我想将它们组合到单个数据集中。使用 tf.dataset.concatenate
但是我发现,在对这个组合数据集进行改组时,该数据集并未在整个数据集的范围内进行改组。但是在每个单独的数据集中都进行了改组。
有什么方法可以解决这个问题?
答案 0 :(得分:3)
将两个Dataset
连接在一起时,将得到第一个的元素,然后是第二个的元素。如果对结果进行混洗,并且混洗缓冲区小于Dataset
的大小,则混合效果不好。
您需要的是从数据集中插入样本。如果使用TF> = 1.9,最好的方法是使用专用的tf.contrib.data.choose_from_datasets
函数。直接来自文档的示例:
datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
tf.data.Dataset.from_tensors("bar").repeat(),
tf.data.Dataset.from_tensors("baz").repeat()]
# Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
choice_dataset = tf.data.Dataset.range(3).repeat(3)
result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)
如果保留批次中的样品顺序和/或比例很重要,则最好改组输入数据集。
如果您使用的是TF的早期版本,则可以依靠zip
,flat_map
和concatenate
的组合,如下所示:
a = tf.data.Dataset.range(3).repeat()
b = tf.data.Dataset.range(100, 105).repeat()
value = (tf.data.Dataset
.zip((a, b))
.flat_map(lambda x, y: tf.data.Dataset.concatenate(
tf.data.Dataset.from_tensors([x]),
tf.data.Dataset.from_tensors([y])))
.make_one_shot_iterator()
.get_next())
sess = tf.InteractiveSession()
for _ in range(10):
print(value.eval())
答案 1 :(得分:0)
不能100%确定,但是您可能需要研究在数据集对象上调用不同操作的顺序。 shuffle()
的行为可能会因顺序而异。另请参阅可能与之相关的this问题。
答案 2 :(得分:0)
洗牌缓冲区的大小是多少?
例如,如果您有3个数据集,每个数据集包含1000个项目,则需要应用shuffle(3000)来随机化所有项目的顺序。
这里是一个例子:
这应该将3000个项目全部洗掉:
dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
dataset = dataset.shuffle(3000)
但是,这将 不 洗牌整个数据集:
dataset1 = dataset1.shuffle(1000)
dataset2 = dataset2.shuffle(1000)
dataset3 = dataset3.shuffle(1000)
dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
答案 3 :(得分:0)
从tensorflow 1.9开始,您还可以使用sample_from_datasets方法。
例如,以下代码
datasets = [tf.data.Dataset.from_tensors("foo").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
tf.data.Dataset.from_tensors("bar").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
tf.data.Dataset.from_tensors("baz").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat()]
dataset = tf.data.experimental.sample_from_datasets(datasets) # from 1.12
# dataset = tf.contrib.data.sample_from_datasets(datasets) # between 1.9 and 1.12
iterator = dataset.make_one_shot_iterator();next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(next_element))
将打印
(0, b'bar')
(0, b'foo')
(1, b'bar')
(0, b'baz')
(2, b'bar')
(1, b'foo')
(1, b'baz')
(2, b'foo')
(2, b'baz')
(0, b'foo')