numpys setdiff1d是否损坏?

时间:2020-04-16 13:44:48

标签: python numpy set-difference

要在我的机器学习项目中选择数据进行训练和验证,我通常使用numpys屏蔽功能。因此,用于选择用于验证和测试数据的索引的典型的重复代码块如下所示:

import numpy as np

validation_split = 0.2

all_idx = np.arange(0,100000)
idxValid = np.random.choice(all_idx, int(validation_split * len(all_idx)))
idxTrain = np.setdiff1d(all_idx, idxValid)

现在,以下内容应始终为真:

len(all_idx) == len(idxValid)+len(idxTrain)

不幸的是,我发现并非总是如此。当我增加从all_idx数组中选择的元素数量时,所得的数字将无法正确累加。这是另一个独立的示例,当我将随机选择的验证索引的数量增加到1000以上时,该示例立即中断:

import numpy as np

all_idx = np.arange(0,100000)
idxValid = np.random.choice(all_idx, 1000)
idxTrain = np.setdiff1d(all_idx, idxValid)

print(len(all_idx), len(idxValid), len(idxTrain))

结果为-> 100000、1000、99005

我很困惑?!请尝试一下。我很高兴理解这一点。

2 个答案:

答案 0 :(得分:1)

idxValid = np.random.choice(all_idx, 10, replace=False)

请注意,您不想在idxValid中重复。为此,您只需要在np.random.choice中插入replace=False

replace boolean, optional
    Whether the sample is with or without replacement

答案 1 :(得分:1)

考虑以下示例:

all_idx = np.arange(0, 100)
print(all_idx)
>>> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]

现在,如果您打印出验证数据集:

idxValid = np.random.choice(all_idx, int(validation_split * len(all_idx)))
print(idxValid)
>>> [31 57 55 45 26 25 55 76 33 69 49 90 46 14 18 30 89 73 47 82]

您实际上可以观察到结果集中存在重复项,因此

len(all_idx) == len(idxValid)+len(idxTrain)

不会导致True

您需要做的是通过传递np.random.choice来确保replace=False进行抽样而不会被废止:

idxValid = np.random.choice(all_idx, int(validation_split * len(all_idx)), replace=False)

现在结果应符合预期:

import numpy as np

validation_split = 0.2

all_idx = np.arange(0, 100)
print(all_idx)

idxValid = np.random.choice(all_idx, int(validation_split * len(all_idx)), replace=False)
print(idxValid)

idxTrain = np.setdiff1d(all_idx, idxValid)
print(idxTrain)

print(len(all_idx) == len(idxValid)+len(idxTrain))

,输出为:

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]

[12 85 96 64 48 21 55 56 80 42 11 92 54 77 49 36 28 31 70 66]

[ 0  1  2  3  4  5  6  7  8  9 10 13 14 15 16 17 18 19 20 22 23 24 25 26
 27 29 30 32 33 34 35 37 38 39 40 41 43 44 45 46 47 50 51 52 53 57 58 59
 60 61 62 63 65 67 68 69 71 72 73 74 75 76 78 79 81 82 83 84 86 87 88 89
 90 91 93 94 95 97 98 99]

True

考虑直接使用train_test_split中的scikit-learn

from sklearn.model_selection import train_test_split


train, test = train_test_split(df, test_size=0.2)