了解scikit-learn GridSearchCV - 参数调整和平均性能指标

时间:2018-02-18 14:04:19

标签: machine-learning scikit-learn grid-search

我试图了解scikit-learn中的GridSearchCV如何在机器学习中实现列车验证测试原则。正如您在下面的代码中看到的那样,我理解它的作用如下:

  1. 将'数据集'拆分为75%和25%,其中75%用于参数调整,25%是保留测试集(第1行)
  2. 初始化一些要搜索的参数(第3到6行)
  3. 将模型拟合到75%的数据集上,但将此数据集拆分为5倍,即每个时间列上60%的数据,测试其他15%,并执行5次(第8-10行) )。 我在这里有第一个和第二个问题,见下文。
  4. 采用性能最佳的模型和参数,测试保持数据(第11-13行)
  5. 问题1 :关于参数空间,步骤3中究竟发生了什么? GridSearchCV是否在五次运行中的每一次尝试每个参数组合(5次),因此总共运行10次? (即'optmizers','init'和'批次'中的单个参数与'epoches'中的2个配对]

    问题2 :'cross_val_score'行打印的分数是多少?这是5次运行中每一次数据单次折叠的10次以上运行的平均值吗? (即整个数据集的平均值的15%)?

    问题3 :假设第5行现在只有1个参数值,这次GridSearchCV实际上并没有搜索任何参数,因为每个参数只有1个值,这是不正确的?

    问题4 :如果在问题3中解释,如果我们对GridSearchCV运行和坚持运行的5倍计算得分的加权平均值,则给出了平均绩效分数在整个数据集上 - 这非常类似于6倍交叉验证实验(即没有网格搜索),除了6倍不完全相同的大小。或者这不是吗?

    非常感谢任何回复!

    X_train_data, X_test_data, y_train, y_test = \
             train_test_split(dataset[:,0:8], dataset[:,8],
                              test_size=0.25,
                              random_state=42) #line 1
    
    model = KerasClassifier(build_fn=create_model, verbose=0)
    optimizers = ['adam']  #line 3
    init = ['uniform']
    epochs = [10,20] #line 5
    batches = [5]   # line 6
    param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, init=init)
    grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=5)  # line 8
    grid_result = grid.fit(X_train_data, y_train) 
    cross_val_score(grid.best_estimator_, X_train_data, y_train, cv=5).mean() #line 10
    best_param_ann = grid.best_params_      #line 11
    best_estimator = grid.best_estimator_
    heldout_predictions = best_estimator.predict(X_test_data)   #line 13
    

1 个答案:

答案 0 :(得分:1)

问题1: 如您所说,您的数据集将分为5个部分。 将尝试每个参数(在您的情况下为2)。对于每个参数,模型将在5个折叠中的4个进行训练。剩下的一个将用作测试。所以你是对的,在你的例子中,你要训练10次模型。

问题2: 'cross_val_score'是5个测试折叠的平均值(准确度,损失或其他)。这样做是为了避免例如获得良好的结果,因为测试集非常简单。

问题3: 是。如果您只有一组参数来尝试进行网格搜索

,这没有任何意义

问题4: 我并不完全理解你的问题。通常,您在火车上使用网格搜索。这允许您将测试集保存为验证集。如果没有交叉验证,您可以找到一个完美的设置,以最大限度地提高测试集的结果,并且您将过度拟合测试集。通过交叉验证,您可以使用精细调整参数尽可能多地播放,因为您不使用验证集进行设置。

在您的代码中,由于您没有很多参数可供使用,因此不需要CV,但如果您开始添加正则化,则可以尝试10+,在这种情况下,需要CV。< / p>

我希望它有所帮助,