Converting my own dataset to CIFAR-10 format: (X_train, y_train), (X_test, y_test)

Time: 2020-05-07 02:46:48

Tags: python

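I have a dataset made up of two folders, each containing images. I want to convert it to the CIFAR-10 format so that I can use it with some code I found on GitHub, which loads its data like this: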

(X_train, y_train), (X_test, y_test), label_names = load_cifar10

Please help!

1 answer:

Answer 0 (score: 0)

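import os

import cv2
import numpy as np


class ImageLoader:
    """Load images into arrays without batching."""

    def __init__(self, train_dir, test_dir):
        """Store the training and test directories."""
        self.train_dir = train_dir
        self.test_dir = test_dir

    def load_data(self):
        """Return [[X_train, y_train], [X_test, y_test]]."""
        features, labels = [], []

        # Map each class folder name to an integer label, since folder
        # names (strings) cannot be converted to a numeric array.
        class_names = sorted(
            name for name in os.listdir(self.train_dir)
            if os.path.isdir(os.path.join(self.train_dir, name)))

        for source in [self.train_dir, self.test_dir]:
            inputs, outputs = [], []
            for label, class_name in enumerate(class_names):
                # Build paths relative to the current source directory.
                class_dir = os.path.join(source, class_name)
                if os.path.isdir(class_dir):
                    for img_name in os.listdir(class_dir):
                        img = cv2.imread(os.path.join(class_dir, img_name))
                        if img is None:  # skip files OpenCV cannot read
                            continue

                        # ...
                        # Modify your image array here, e.g. resize it so
                        # that every image has the same shape.
                        # ...

                        inputs.append(img)
                        outputs.append(label)

            # Shuffle images and labels together.
            combined = list(zip(inputs, outputs))  # zip as list for Python 3
            np.random.shuffle(combined)
            inputs, outputs = zip(*combined)
            features.append(inputs)
            labels.append(outputs)

        return [[np.array(features[0], dtype=np.float32),
                 np.array(labels[0], dtype=np.int64)],
                [np.array(features[1], dtype=np.float32),
                 np.array(labels[1], dtype=np.int64)]]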

cifar10 = ImageLoader('train_path', 'test_path')
(trainX, trainY), (testX, testY) = cifar10.load_data()
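
The question mentions two folders of images and a label_names return value, so the data may be organised as one folder per class with no separate train/test split. If that is the case, a rough sketch like the one below could build the split and the label names first; the resulting file paths can then be read with cv2.imread as in the loader. The function name split_image_folders and the parameters test_fraction and seed are hypothetical, chosen only for illustration, and the sketch assumes every file inside a class folder is an image.

```python
import os
import random


def split_image_folders(root_dir, test_fraction=0.2, seed=0):
    """Return (train_paths, train_labels), (test_paths, test_labels), label_names.

    Assumes root_dir contains one sub-folder per class (hypothetical layout).
    """
    # Sorted folder names give a stable class-name -> integer mapping.
    label_names = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name)))

    # Collect (file path, integer label) pairs for every class folder.
    samples = []
    for label, class_name in enumerate(label_names):
        class_dir = os.path.join(root_dir, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, file_name), label))

    # Seeded shuffle so the split is reproducible between runs.
    random.Random(seed).shuffle(samples)

    # The first n_test shuffled samples become the test set.
    n_test = int(len(samples) * test_fraction)
    test, train = samples[:n_test], samples[n_test:]

    def unzip(pairs):
        paths = [path for path, _ in pairs]
        labels = [label for _, label in pairs]
        return paths, labels

    return unzip(train), unzip(test), label_names
```

Calling it as (train_paths, y_train), (test_paths, y_test), label_names = split_image_folders('dataset_root') gives integer labels whose positions index into label_names, which is the same shape of result the GitHub code in the question expects once the paths have been loaded into arrays.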