Spark:覆盖库方法

时间:2016-04-17 23:24:55

标签: scala apache-spark machine-learning pyspark data-science

我想对spark.ml.classification.LogisticRegression的scala代码进行一些修改,而不必重建整个Spark。 因为我们可以将jar文件附加到spark-submit或pySpark的执行中。是否可以编译LogisticRegression.java的修改副本并覆盖Spark的默认方法,或者至少创建新的方法?感谢。

1 个答案:

答案 0 :(得分:2)

创建一个扩展org.apache.spark.ml.classification.LogisticRegression的新类,并在不修改源代码的情况下重写相应的方法应该有效。

class CustomLogisticRegression
  extends
    LogisticRegression {
  override def toString(): String = "This is overridden Logistic Regression Class"
}

使用新的CustomLogisticRegression

运行Logistic回归
val data = sqlCtx.createDataFrame(MLUtils.loadLibSVMFile(sc, "/opt/spark/spark-1.5.2-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt"))

val customLR = new CustomLogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

val customLRModel = customLR.fit(data)

val originalLR = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

val originalLRModel = originalLR.fit(data)

// Print the intercept for logistic regression
println(s"Custom Class's Intercept: ${customLRModel.intercept}")
println(s"Original Class's Intercept: ${originalLRModel.intercept}")
println(customLR.toString())
println(originalLR.toString())

输出:

Custom Class's Intercept: 0.22456315961250317
Original Class's Intercept: 0.22456315961250317
This is overridden Logistic Regression Class
logreg_1cd811a145d7