sklearn的预测建模管道

时间:2018-06-22 02:07:16

标签: python machine-learning scikit-learn cross-validation

我已经逐渐从R转到Python进行预测建模。我想知道什么是通过交叉验证进行超参数优化并将经过训练的模型应用于新实例的最佳管道。

在下面,您将看到一个我使用随机森林所做的快速示例。我想知道是否可以,您会从中添加或删除什么?

#import data sets
train_df = pd.read_csv('../input/train.csv') 
test_df = pd.read_csv('../input/test.csv')

#get the predictors only
X_train = train_df.drop(["ID", "target"], axis=1) 
y_train = np.log1p(train_df["target"].values)  

X_test = test_df.drop(["ID"], axis=1)

#grid to do the random search

from sklearn.model_selection import RandomizedSearchCV 

n_estimators = [int(x) for x in np.linspace(start = 200, stop = 2000, num = 10)]
max_features = ['auto', 'sqrt']
max_depth = [int(x) for x in np.linspace(10, 110, num = 11)] 
max_depth.append(None)
min_samples_split = [2, 5, 10]
min_samples_leaf = [1, 2, 4]
bootstrap = [True, False]  

# Create the random grid
random_grid = {'n_estimators': n_estimators,
           'max_features': max_features,
           'max_depth': max_depth,
           'min_samples_split': min_samples_split,
           'min_samples_leaf': min_samples_leaf,
           'bootstrap': bootstrap}

#Create the model to tune
rf = RandomForestRegressor()
rf_random= RandomizedSearchCV(estimator = rf, param_distributions = random_grid, n_iter = 100, cv = 10, verbose=2, random_state=42, n_jobs =10)
#fit the random search model
rf_random.fit(X_train, y_train) 

#get the best estimator
best_random = rf_random.best_estimator_ 

# train again with the best parameters on the whole training data?
best_random.fit(X_train,y_train)

#apply the best predictor to the test set
pred_test_rf = np.expm1(best_random.predict(X_test)) 
  1. .best_estimator_是否使用在网格搜索中找到的最佳参数实例化模型?

  2. 如果是,我是否需要重新训练整个训练数据(如上所述),还是已经被重新训练?

  3. 我想知道这种方法是否可行,或者在python中使用sklearn进行此操作有哪些最佳实践。

1 个答案:

答案 0 :(得分:2)

1)是,它是由best_params_的{​​{1}}发起的估算器

2)不,它已经对整个数据进行了训练,不需要做rf_random

RandomizedSearchCV具有一个参数best_random.fit(X_train,y_train),其默认值为'refit'

True

3)您的方法似乎还可以。这是标准方式。其他因素可能取决于各种因素,例如数据类型,数据大小,使用的算法(估计器),探索可能性的时间等。但这部分最适合https://stats.stackexchange.com