调整Spark DataFrame API Logistic回归模型的拦截

时间:2016-11-14 10:10:28

标签: apache-spark apache-spark-mllib

我在Spark中训练逻辑回归。但是,由于训练数据的细节,我需要手动调整模型,即改变截距。

使用RDD api很容易 - 只是实例化一个新的LogisticRegressionModel:

val intercept = model.intercept() + adjustment
val model = new LogisticRegressionModel(model.weights(), intercept)

但是,DataFrame API中的LogisticRegressionModel构造函数是私有的。如何手动调整模型?

1 个答案:

答案 0 :(得分:0)

今天下午我遇到了同样的问题而且我处于测试模式,无论如何都试图让它发生,所以我不在乎它有多脏:从模型中获取系数,得到截距,调整它,然后手动使用他们在Spark中使用的code进行预测(查找BLAS.dotmarginscore)。在某些时候,他们使用BLAS.dot,而BLAS是私有的。再次执行相同操作,检索dot的代码,处理SparseVector / DenseVector,然后就可以了。很脏,但它有效。