没有 scikit-learn 的分层训练/验证/测试拆分

时间:2021-05-04 08:28:08

标签: python scikit-learn split

我正在研究 mnist 数据集,其中包含 1797 张图像,表示 0 到 10 位数字。我想将数据集拆分为训练、验证和测试子数据,以便为每个 sub_data 指定相同数量的每个数字。 python中没有sklearn库如何进行分层?

提前感谢您的回答。

1 个答案:

答案 0 :(得分:1)

要进行分层数据拆分,您需要知道每个数据点属于哪个类。如果你有一个数据点列表和一个对应的类列表,你可以提取属于某个类的所有点,并按照输入的比例进行分割。

下面是一些实现这个想法的代码: 请注意,您必须添加一些数组来跟踪数据点在循环中拆分后所属的类。

import numpy as np
train, valid, test = 0.6, 0.2, 0.2
data_points = np.random.rand(1000, 32, 32)
classes     = np.random.randint(0, 10, size = (1000,))
class_set   = np.unique(classes)
data_train  = []
data_valid  = []
data_test   = []
for class_i in class_set:
    data_inds    = np.where(classes==class_i)
    data_i       = data_points[data_inds, ...]
    N_i          = len(data_inds)
    N_i_train    = int(N_i*train)
    N_i_valid    = int(N_i*valid)
    data_train.append(data_i[:N_i_train])
    data_valid.append(data_i[N_i_train:N_i_train+N_i_valid])
    data_test.append(data_i[N_i_train+N_i_valid:])
    
data_train = np.concatenate(data_train)
data_valid = np.concatenate(data_valid)
data_test = np.concatenate(data_test)