vowpal-wabbit:使用多次传球,保持,&保持期以避免过度拟合?

时间:2017-04-04 14:35:25

标签: machine-learning neural-network supervised-learning vowpalwabbit

我想使用awesome vw --binary --nn 4 train.vw -f category.model 使用以下命令训练二进制S形前馈网络进行类别分类:

vw --binary -t -i category.model -p test.vw

测试一下:

Number of Training Passes

但是我的结果非常糟糕(与我的线性svm估算器相比)。

我发现了一条评论,我应该使用--passes arg参数(holdout_period)。

所以我的问题是如何知道培训通行证的数量以便不再获得再培训模型

P.S。我应该使用df = pd.read_clipboard(header=None) print df.corr() 参数吗?如何?

1 个答案:

答案 0 :(得分:3)

问题中的测试命令不正确。它没有输入(-p ...表示输出预测)。此外,还不清楚您是否要测试预测,因为它显示test但使用的命令有-p ...

测试表示您已标记数据,并且您正在评估模型的质量。严格来说:预测意味着您没有标签,因此您实际上无法知道您的预测有多好。实际上,您还可以预测标记数据,通过忽略它们来假装它没有标签,然后评估这些预测的好坏,因为您实际上有标签。

一般而言:

  • 如果您想进行二进制分类,则应使用{-1, 1}中的标签并使用--loss_function logistic--binary这是一个独立的选项,这意味着你希望预测是二进制的(给你更少的信息)。

  • 如果你已经有一个单独的标签测试集,你就不需要坚持。

vw中的保持机制旨在替换测试集并避免过度拟合,仅在使用多次传递时才相关,因为在单次传递中所有示例均为有效地坚持;每个下一个(但未见过的)示例被视为1)未标记用于预测,并且2)标记用于测试和模型更新。 IOW:你的火车套装实际上也是你的测试装置。

所以你可以在没有坚持的火车上进行多次传球

 vw --loss_function logistic --nn 4 -c --passes 2 --holdout_off train.vw -f model

然后使用单独的标记测试集测试模型:

 vw -t -i model test.vw

或在同一列车上进行多次通过,并将一些保留作为测试集

vw --loss_function logistic --nn 4 -c --passes 20 --holdout_period 7 train.vw -f model

如果您没有测试集,并且希望通过使用多次传递更加强大,那么您可以要求vw坚持每个N示例(默认N为10,但您可以使用--holdout_period <N>显式覆盖它,如上所示)。在这种情况下,您可以指定更多的通过次数,因为vw会在保留集的损失开始增长时自动提前终止。

你注意到你提前终止,因为vw会打印出类似的内容:

passes used = 5
...
average loss = 0.06074 h

表示在提前停止之前实际只使用了N个通道中的5个,并且示例的保留子集上的错误为0.06074(尾随h表示这是保持不变)。

如您所见,传递次数 holdout-period 是完全独立的选项。

为了提高模型的可信度,您可以使用其他优化,更改holdout_period,尝试其他--nn args。您可能还需要检查vw-hypersearch实用程序(在utl子目录中)以帮助查找更好的超参数。

以下是在源代码中包含的其中一个测试集上使用vw-hypersearch的示例:

$ vw-hypersearch 1 20 vw --loss_function logistic --nn % -c --passes 20 --holdout_period 11 test/train-sets/rcv1_small.dat --binary
trying 13 ............. 0.133333 (best)
trying 8 ............. 0.122222 (best)
trying 5 ............. 0.088889 (best)
trying 3 ............. 0.111111
trying 6 ............. 0.1
trying 4 ............. 0.088889 (best)
loss(4) == loss(5): 0.088889
5       0.08888

指示45应该是--nn的良好参数,导致11个示例中1的保留子集丢失0.08888。