如何在GridSearchCV中提供预定义的k-fold

时间:2018-02-27 12:34:36

标签: scikit-learn cross-validation grid-search

我想使用GridSearchCV进行参数调整和评估超过10个预定义的选定数据折叠(作为数据索引列表的列表)。

有没有人知道如何在scikit中为GridSearchCV提供10个预定义测试折叠列表?

splits=[ [0,10,9,1,2,..] ,[3,5,7,..],[23,4,34,..]] #len(split)=10

greed_search = GridSearchCV(estimator, param_grid=parameters, cv=splits,scoring=scoring, refit=score, error_score=0, n_jobs=n_jobs)

1 个答案:

答案 0 :(得分:1)

我认为您需要对此折叠进行预处理:

new_splits = []

for i in range(len(splits)):
    train = [j for i in splits[:i] + splits[i + 1:] for j in i]
    test = splits[i]
    new_splits.append([train, test])

不仅可以迭代测试部件,还可以训练部件