partial_fit perfom如何与模型持久性结合?

时间:2019-06-21 13:27:18

标签: machine-learning scikit-learn

我正在尝试使用MLPRegressor在scikit-learn中创建回归模型,但是一开始并不是所有训练数据都可用。因此,我尝试使用partial_fit()进行一些增量学习。

通过在训练样本中多次使用partial_fit(),我可以减少损失。但是,在转储/加载模型以进行再次训练之后,无论我运行多少次迭代,损失都保持不变。即使第二训练阶段超出了先前训练的数据,也是如此。

为什么第二次培训课程没有更新?为什么转储/加载引起问题?

通过在可用数据上运行一个fit()实例,我能够减少损失,但是据我所知,fit()partial_fit()是如何相互作用的,听起来像坏习惯。

例如,以下

mlp = load('wv3Model.joblib')
numDataSets = 2
for j in np.arange(numDataSets-1):
    trainName = 'trainSet%d' %(j+1)
    truthName = 'truthSet%d' %(j+1)
    trainSamples = load(trainName)
    trainTruth = load(truthName)

    mlp.fit(trainSamples,trainTruth)
    for i in np.arange(maxIt):
        mlp.partial_fit(trainSamples,trainTruth)

dump(mlp,'wv3Model.joblib')

如前所述工作,即即使第二次运行脚本后,损失也减少了。但是,通过删除mlp.fit(trainSamples,trainTruth),我得到以下输出

Iteration 1, loss = 76758.28391730
Iteration 2, loss = 76625.56216054
Iteration 3, loss = 75856.95921531
Iteration 4, loss = 73252.93073076
Iteration 5, loss = 67519.66342697
Iteration 6, loss = 58193.69161805
Iteration 7, loss = 48917.46644853
Iteration 8, loss = 39812.66552474
Iteration 9, loss = 30634.87573078
Iteration 10, loss = 22985.60612562
Iteration 11, loss = 17407.24146271
Iteration 12, loss = 15207.91405823
Iteration 13, loss = 15207.91405823
Iteration 14, loss = 15207.91405823
Iteration 15, loss = 15207.91405823
Iteration 16, loss = 15207.91405823
Iteration 17, loss = 15207.91405823
Iteration 18, loss = 15207.91405823
Iteration 19, loss = 15207.91405823
Iteration 20, loss = 15207.91405823
Iteration 21, loss = 15207.91405823

具有以下(先前已初始化的)模型

mlp = MLPRegressor(hidden_layer_sizes = hLayers, 
                   max_iter = 1, 
                   batch_size = batchSize,
                   n_iter_no_change=500000,
                   tol=tol,
                   alpha=alpha,
                   verbose=True,
                   warm_start=True)

添加mlp.fit(trainSamples,trainTruth)行是有效的,但是对我来说这没有意义。如果有人能解释为什么一种方法行得通或者为什么另一种方法行不通(或两种都行),那将是很棒的。

附加说明:如果我在这里没有提到的其他错误用法,请随时对此发表评论。我绝不是ML的专家。

0 个答案:

没有答案