如何合并两个(或多个)TensorFlow数据集?

时间:2019-06-11 15:05:28

标签: python tensorflow tensorflow-datasets tensorflow2.0

我已按以下方式获取了CelebA数据集,其中包含3个分区

>>> celeba_bldr = tfds.builder('celeb_a')
>>> datasets = celeba_bldr.as_dataset()
>>> datasets.keys()
dict_keys(['test', 'train', 'validation'])

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

现在,我想将它们全部合并到一个数据集中。例如,我将需要将火车和验证组合在一起,或者可能将它们全部合并在一起,然后根据我自己的不同主题-不相交标准对其进行拆分。反正有这样做吗?

我在文档https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset

中找不到执行此操作的任何选项。

1 个答案:

答案 0 :(得分:1)

查看您链接的文档,数据集似乎具有concatenate方法,因此我认为您可以将联合数据集设为:

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

ds = ds_train.concatenate(ds_test).concatenate(ds_valid)

请参阅:https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#concatenate