我已按以下方式获取了CelebA数据集,其中包含3个分区
>>> celeba_bldr = tfds.builder('celeb_a')
>>> datasets = celeba_bldr.as_dataset()
>>> datasets.keys()
dict_keys(['test', 'train', 'validation'])
ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']
现在,我想将它们全部合并到一个数据集中。例如,我将需要将火车和验证组合在一起,或者可能将它们全部合并在一起,然后根据我自己的不同主题-不相交标准对其进行拆分。反正有这样做吗?
我在文档https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset
中找不到执行此操作的任何选项。答案 0 :(得分:1)
查看您链接的文档,数据集似乎具有concatenate
方法,因此我认为您可以将联合数据集设为:
ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']
ds = ds_train.concatenate(ds_test).concatenate(ds_valid)
请参阅:https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#concatenate