如何在Tensorflow数据集中定义自定义拆分?

时间:2019-03-13 10:40:04

标签: tensorflow tensorflow-datasets

我正在浏览新的Tensorflow数据集库的文档,尝试添加新的数据集。我想为自己的数据集指定自定义数据集拆分,并在文档中说:

  

如果数据集带有预定义的拆分(例如,MNIST具有   训练和测试拆分),将这些拆分保留在DatasetBuilder中。如果   这是您自己的数据,您可以决定自己的拆分,我们建议   使用({TRAIN:80%, VALIDATION: 10%, TEST: 10%)的拆分。用户可以   总是通过tfds.Split.subsplit获得细分。

我正在使用GeneratorBasedBuilder,它是tfds.core.DatasetBuilder的子类,可简化定义数据集的过程。因此,我必须实现方法_split_generators,如下所示:

def _split_generators(self, dl_manager):
    # Download source data
    extracted_path = dl_manager.download_and_extract(...)

    # Specify the splits
    return [
        tfds.core.SplitGenerator(
            name="train",
            num_shards=10,
            gen_kwargs={
                "images_dir_path": os.path.join(extracted_path, "train"),
                "labels": os.path.join(extracted_path, "train_labels.csv"),
            },
        ),
        tfds.core.SplitGenerator(
            name="test",
            num_shards=1,
            gen_kwargs={
                "images_dir_path": os.path.join(extracted_path, "test"),
                "labels": os.path.join(extracted_path, "test_labels.csv"),
            },
        ),
    ]

是否可以通过这种方式指定自定义拆分,还是应该从较低级别的DatasetBuilder类继承?

0 个答案:

没有答案