我正在研究 mnist 数据集,其中包含 1797 张图像,表示 0 到 10 位数字。我想将数据集拆分为训练、验证和测试子数据,以便为每个 sub_data 指定相同数量的每个数字。 python中没有sklearn库如何进行分层?
提前感谢您的回答。
答案 0 :(得分:1)
要进行分层数据拆分,您需要知道每个数据点属于哪个类。如果你有一个数据点列表和一个对应的类列表,你可以提取属于某个类的所有点,并按照输入的比例进行分割。
下面是一些实现这个想法的代码: 请注意,您必须添加一些数组来跟踪数据点在循环中拆分后所属的类。
import numpy as np
train, valid, test = 0.6, 0.2, 0.2
data_points = np.random.rand(1000, 32, 32)
classes = np.random.randint(0, 10, size = (1000,))
class_set = np.unique(classes)
data_train = []
data_valid = []
data_test = []
for class_i in class_set:
data_inds = np.where(classes==class_i)
data_i = data_points[data_inds, ...]
N_i = len(data_inds)
N_i_train = int(N_i*train)
N_i_valid = int(N_i*valid)
data_train.append(data_i[:N_i_train])
data_valid.append(data_i[N_i_train:N_i_train+N_i_valid])
data_test.append(data_i[N_i_train+N_i_valid:])
data_train = np.concatenate(data_train)
data_valid = np.concatenate(data_valid)
data_test = np.concatenate(data_test)