加载像Mnist这样的自定义数据集(Tensorflow Python)

时间:2020-02-19 08:44:29

标签: python tensorflow dataset tensorflow-datasets mnist

我正在实验https://github.com/astirn/IIC中的聚类模型 (已经尝试与他联系)

与大多数研究论文一样,它使用Mnist数据集。 在这里,他们首先将数据集名称定义为“ mnist”,这足以使张量流从其标准在线数据集中导入mnist。 然后,他使用tensorflow_dataset.load()函数加载数据集

我已经为我的数据集创建了一个tfrecord文件,现在我只需要替换前面脚本中指向“ mnist”(下面代码中的第1行)的部分,而不是指向我的本地数据集即可。

我只用第一行的文件路径替换'mnist'吗?

实际训练模型文件中的代码

if __name__ == '__main__':
# pick a data set
    DATA_SET = 'mnist'

# define splits
    DS_CONFIG = {
        # mnist data set parameters
        'mnist': {
            'batch_size': 700,
            'num_repeats': 5,
            'mdl_input_dims': [24, 24, 1]}
    }

# load the data set
    TRAIN_SET, TEST_SET, SET_INFO = load(data_set_name=DATA_SET, **DS_CONFIG[DATA_SET])

# configure the common model elements
    MDL_CONFIG = {
    # mist hyper-parameters
        'mnist': {
            'num_classes': SET_INFO.features['label'].num_classes,
            'learning_rate': 1e-4,
            'num_repeats': DS_CONFIG[DATA_SET]['num_repeats'],
            'save_dir': None},
    }

“数据准备文件”中的代码,他在其中调用tensorflor_dataset.load作为tfds.load的数据集:

def load(data_set_name, **kwargs):
    """
    :param data_set_name: data set name--call tfds.list_builders() for options
    :return:
        train_ds: TensorFlow Dataset object for the training data
        test_ds: TensorFlow Dataset object for the testing data
        info: data set info object
    """
    # get data and its info
    ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True)

感谢您的帮助

1 个答案:

答案 0 :(得分:0)

根据docs,您需要将download参数用作Falsedata_dir并使用目录名称:

ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True, download=False, data_dir="/path/to/file")