RandomOversampler的管道,RandomForestClassifier& GridSearchCV

时间:2018-01-26 12:32:32

标签: python scikit-learn random-forest grid-search oversampling

我正在研究文本二进制分类问题。由于类非常不平衡,我必须采用RandomOversampler()等采样技术。然后为了分类,我将使用RandomForestClassifier(),其参数需要使用GridSearchCV()进行调整。 我正在尝试创建一个管道来按顺序执行这些操作但到目前为止失败了。它会抛出“无效参数”。

param_grid = {
             'n_estimators': [5, 10, 15, 20],
             'max_depth': [2, 5, 7, 9]
         }
grid_pipe = make_pipeline(RandomOverSampler(),RandomForestClassifier())
grid_searcher = GridSearchCV(grid_pipe,param_grid,cv=10)
grid_searcher.fit(tfidf_train[predictors],tfidf_train[target])

2 个答案:

答案 0 :(得分:2)

您在params中定义的参数适用于RandomForestClassifier,但在gridSearchCV中,您没有传递RandomForestClassifier个对象。

您正在传递管道对象,您必须为其重命名参数以访问内部RandomForestClassifier对象。

将它们更改为:

param_grid = {
             'randomforestclassifier__n_estimators': [5, 10, 15, 20],
             'randomforestclassifier__max_depth': [2, 5, 7, 9]
             }

它会起作用。

答案 1 :(得分:1)

感谢A2A。理想情况下,参数定义如下:

  1. 为要在数据上应用的变压器创建管道

pipeline = make_pipeline([('variable initialization 1',transformers1()),('variable initialization 2',transformers2()),]

注意:在关闭方括号之前,请不要忘记在管道后面加上','

eg:pipeline = make_pipeline([('random_over_sampler',RandomOverSampler()),('RandomForestClassifier', RandomForestClassifier()),]

  1. 创建参数网格
param_grid = {'transformations/algorithm'__'parameter_in_transformations/algorithm':[parameters]}

eg: param_grid = {RandomOverSampler__sampling_strategy:['auto']}