How do I split machine learning data by class?

Time: 2019-09-17 14:12:31

Tags: python dataset

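I'm trying to split (X_train, Y_train) according to the different classes in Y_train. X_train consists of 50,000 images of size 25 × 25, and Y_train consists of 50,000 binary labels (0 or 1). I tried to separate the data with the following code: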

def split(X_train, Y_train):
    if Y_train == 0:
        0_only = []
        0_only.append(X_train)
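The attempt above cannot work as written: `0_only` is not a valid Python name (identifiers cannot start with a digit), and `if Y_train == 0:` compares the entire label array at once, which NumPy refuses to reduce to a single True/False (it raises a ValueError for arrays with more than one element). Below is a minimal loop-based sketch of the same idea, assuming `X_train` and `Y_train` are NumPy arrays with one scalar label per sample along the first axis; `zero_only`, `one_only` and the demo arrays are hypothetical names used only for illustration.

```python
import numpy as np


def split(X_train, Y_train):
    """Split samples into two arrays by their 0/1 label (loop version).

    Assumes Y_train is 1-D with one label per sample, in the same
    order as the samples along axis 0 of X_train.
    """
    zero_only, one_only = [], []
    for x, y in zip(X_train, Y_train):
        # y is a single label here, so the comparison yields a plain bool
        if y == 0:
            zero_only.append(x)
        else:
            # Labels are binary, so anything that is not 0 is class 1
            one_only.append(x)
    return np.array(zero_only), np.array(one_only)


# Hypothetical demo data: 6 "images" of shape 25x25 with binary labels
X_demo = np.arange(6 * 25 * 25).reshape(6, 25, 25)
Y_demo = np.array([0, 1, 0, 0, 1, 1])
X0, X1 = split(X_demo, Y_demo)
print(X0.shape, X1.shape)  # (3, 25, 25) (3, 25, 25)
```

The loop is easy to follow, but for 50,000 samples it is typically much slower than a vectorized selection with NumPy indexing.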

1 answer:

Answer 0 (score: 2)

This should do what you want:

import numpy as np

# Find the indices of the samples in Y_train that are zero
idx_zero = np.where(Y_train == 0)[0]

# Get subset of X_train and Y_train where Y_train is zero
X_train_zero = X_train[idx_zero]
Y_train_zero = Y_train[idx_zero]

You can then do the same for class 1 using np.where(Y_train == 1)[0].
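Indexing with a boolean mask such as `Y_train == 0` selects the same samples as indexing with the `np.where(...)[0]` indices, and looping over `np.unique(Y_train)` handles every class at once instead of repeating the code per label. The sketch below assumes `Y_train` is a 1-D array aligned with axis 0 of `X_train`; the helper `split_by_class` and the random stand-in data are hypothetical, not part of the original answer.

```python
import numpy as np


def split_by_class(X, Y):
    """Return {label: (X_subset, Y_subset)} for every distinct label in Y.

    Assumes Y is 1-D with one label per sample along axis 0 of X.
    """
    Y = np.asarray(Y)
    # .tolist() turns the NumPy labels into plain Python ints for the dict keys
    return {label: (X[Y == label], Y[Y == label]) for label in np.unique(Y).tolist()}


# Hypothetical stand-in for the real data (50,000 images of 25x25)
rng = np.random.default_rng(0)
X_train = rng.random((100, 25, 25))
Y_train = rng.integers(0, 2, size=100)

splits = split_by_class(X_train, Y_train)
X_train_zero, Y_train_zero = splits[0]
X_train_one, Y_train_one = splits[1]
print(X_train_zero.shape, X_train_one.shape)
```

With real data, replace the random arrays with your own `X_train` and `Y_train`; the same function also works if the labels later have more than two classes.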