使用基于Spark数据集的ML API时初始化逻辑回归系数?

时间:2017-07-03 19:03:23

标签: apache-spark apache-spark-mllib apache-spark-ml

默认情况下,逻辑回归训练将系数初始化为全零。但是,我想自己初始化系数。这将是有用的,例如,如果先前的训练运行在几次迭代后崩溃 - 我可以简单地用最后一组已知系数重新开始训练。

这是否适用于任何基于数据集/数据框的API,最好是Scala?

查看Spark源代码,似乎方法setInitialModel来初始化模型及其系数but it's unfortunately marked as private

基于RDD的API似乎允许初始化系数:LogisticRegressionWithSGD.run(...)的一个重载接受initialWeights向量。但是,我想使用基于数据集的API而不是基于RDD的API,因为(1)前者支持弹性网络正规化(我无法弄清楚如何使用基于RDD的逻辑回归做弹性网络)和(2)因为the RDD-based API is in maintenance mode

我总是可以尝试使用反射来调用私有的setInitialModel方法,但是如果可能的话我想避免这种情况(也许这甚至无法工作......我也不能这样做告诉我setInitialModel是否被标记为私有。)

1 个答案:

答案 0 :(得分:0)

随意覆盖该方法。是的,您需要将该课程复制到您自己的工作区域。那没关系:不要害怕

当您通过mavensbt构建项目时,您的本地副本将“赢”并遮挡MLlib中的项目。幸运的是,同一个包中的其他类将加阴影。

我多次使用这种方法来覆盖Spark类:实际上你的构建时间也应该很小。