我正在尝试根据Y_train的不同类别拆分(X_train,Y_train)。 X_train由50,000个25 X 25图像组成,Y_train由50,000个二进制分类(0或1)组成。我试图用下面的代码放置数据
def split(X_train, Y_train):
if Y_train == 0:
0_only = []
0_only.append(X_train)
答案 0 :(得分:2)
这可能会做您想要的:
# Find the indices of the samples in Y_train that are zero
idx_zero = np.where(Y_train == 0)[0]
# Get subset of X_train and Y_train where Y_train is zero
X_train_zero = X_train[idx_zero]
Y_train_zero = Y_train[idx_zero]
然后您可以使用np.where(Y_train == 1)[0]
做同样的事情。