在scikit-learn中训练神经网络的早期停止

时间:2014-02-21 08:34:14

标签: python machine-learning neural-network scikit-learn cross-validation

这个问题非常具体到Python库scikit-learn。如果最好将其发布到其他地方,请告诉我。谢谢!

现在的问题......

我有一个基于BaseEstimator的前馈神经网络类ffnn,我用SGD训练。它工作正常,我也可以使用GridSearchCV()并行训练。

现在我想在函数ffnn.fit()中实现提前停止,但为此我还需要访问fold的验证数据。一种方法是更改​​sklearn.grid_search.fit_grid_point()中的行

clf.fit(X_train, y_train, **fit_params)

类似

clf.fit(X_train, y_train, X_test, y_test, **fit_params)

并更改ffnn.fit()以获取这些参数。这也会影响sklearn中的其他分类器,这是一个问题。我可以通过检查fit_grid_point()中的某种标志来避免这种情况,该标志告诉我何时以上述两种方式之一调用clf.fit()。

在没有编辑sklearn库中的任何代码的情况下,有人可以建议使用不同的方法吗?

或者,将X_train和y_train进一步拆分为训练/验证集并检查一个好的停止点,然后在所有X_train上重新训练模型是否正确?

谢谢!

2 个答案:

答案 0 :(得分:7)

您可以让神经网络模型在内部通过使用X_train函数从传递的y_traintrain_test_split中提取验证集。

编辑:

  

或者,将X_train和y_train进一步拆分为训练/验证集并检查一个好的停止点,然后在所有X_train上重新训练模型是否正确?

是的,但那会很贵。您可以找到停止点,然后只需对用于查找停止点的验证数据执行一次额外的传递。

答案 1 :(得分:1)

有两种方法:

<强> 第一:

在进行x_train和x_test分割时。您可以从x_train获取0.1分割并保留以进行验证x_dev:

x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.25)

x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train, test_size=0.1)

clf = GridSearchCV(YourEstimator(), param_grid=param_grid,)
clf.fit(x_train, y_train, x_dev, y_dev)

您的估算工具将如下所示,并使用x_dev,y_dev实现提前停止

class YourEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, param1, param2):
        # perform initialization
        #

    def fit(self, x, y, x_dev=None, y_dev=None):
        # perform training with early stopping
        #

<强> 第二

您不会在x_train上执行第二次拆分,但会在Estimator的fit方法中取出dev设置

x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.25)

clf = GridSearchCV(YourEstimator(), param_grid=param_grid)
clf.fit(x_train, y_train)

您的估算工具将如下所示:

class YourEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, param1, param2):
        # perform initialization
        #

    def fit(self, x, y):
        # perform training with early stopping
        x_train, x_dev, y_train, y_dev = train_test_split(x, y, 
                                                         test_size=0.1)