交叉验证:查找不属于训练集的测试集的行索引

时间:2016-02-06 18:09:13

标签: python numpy matrix machine-learning cross-validation

我需要做的是从numpy矩阵中随机挑选(替换)50行,以便训练线性分离器。

然后,我需要使用我没有选择的行测试线性分隔符。

对于第一部分,A是我的完整数据矩阵,我这样做:

A_train = A[np.random.randint(A.shape[0],size=50),:]

但我目前找不到有效的方法:

A_test = ...

其中A_test不包含与A_train相同的行。我该怎么做?

这个问题的关键是A是一个n x m矩阵,而不是一维矩阵......

1 个答案:

答案 0 :(得分:1)

您可以使用np.setdiff1d查找未包含在训练集中的行索引:

import numpy as np

gen = np.random.RandomState(0)

n_total = 1000
n_train = 800

train_idx = gen.choice(n_total, size=n_train)
test_idx = np.setdiff1d(np.arange(n_total), train_idx)

替换抽样的一个后果是,符合测试集的示例数量将根据训练集中重复示例的数量而有所不同:

print(test_idx.size)
# 439

如果您想确保测试集的大小一致,您可以从不在训练集中的索引集中重新取样:

n_test = 200
test_idx2 = gen.choice(test_idx, size=n_test)

如果您实际上并不关心使用替换进行采样,那么更简单的选择是生成所有索引的随机排列,然后将前N个作为训练示例,其余作为测试示例:

idx = gen.permutation(n_total)
train_idx, test_idx = idx[:n_train], idx[n_train:]

或者您可以使用np.random.shuffle将阵列中的行拖放到位。

我还应该指出,scikit-learn有various convenience methods用于将数据划分为训练和测试集以进行交叉验证。