训练具有相同索引的拆分

时间:2018-07-06 06:49:47

标签: python python-3.x pandas scikit-learn cross-validation

我希望具有相同索引的行存在于同一集合中-训练或测试,但不能同时存在。我该如何在sklearn中做到这一点?例如:

df = pd.DataFrame({'A': [1, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 6], 'B': random.sample(range(10, 100), 12)})
df.set_index('A', inplace = True)

我想要实现:

具有索引1、3、5、6的火车 带有索引2、4的测试集

如何通过GridSearchCV确保这一点?

1 个答案:

答案 0 :(得分:3)

将它们设置为'group'。 sklearn中的大多数拆分器在其中都支持名为groups的参数,可以将其设置为执行您想要的操作

示例:

您可以使用GroupKFoldGroupShuffleSplit

group_kfold = GroupKFold(n_splits=3)
for train_index, test_index in group_kfold.split(df, groups=df.index):
    print("Train", df.iloc[train_index].index)
    print("Test", df.iloc[test_index].index)

Output: 
('Train', Int64Index([1, 1, 1, 2, 2, 3, 4, 4], dtype='int64', name=u'A'))
('Test', Int64Index([5, 6, 6, 6], dtype='int64', name=u'A'))

('Train', Int64Index([2, 2, 4, 4, 5, 6, 6, 6], dtype='int64', name=u'A'))
('Test', Int64Index([1, 1, 1, 3], dtype='int64', name=u'A'))

('Train', Int64Index([1, 1, 1, 3, 5, 6, 6, 6], dtype='int64', name=u'A'))
('Test', Int64Index([2, 2, 4, 4], dtype='int64', name=u'A'))

您可以看到上次火车测试拆分符合您的要求。所有折叠都将包含训练或测试数据,但不能同时包含两者。