How to get the gradient of the loss function during training in Spark 2.0?

Time: 2017-04-09 14:30:04

Tags: apache-spark apache-spark-mllib

We are currently working with Spark 2.0, and I would like to know how the gradient of the loss function changes during Spark training, so that it can be used to visualize the training process. For example, I have the following code:

// Imports needed for this snippet
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

// Load training data in LIBSVM format.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)

// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)
  .run(training)

I know there are some classes in the package "org.apache.spark.mllib.evaluation" that can be used to obtain metrics from the model, but I still cannot see how the gradient of the loss function changed during the training process.

Is there any solution?

1 answer:

Answer 0: (score: 2)

Unfortunately, spark-mllib does not support this kind of query, and given that it has been deprecated, it is unlikely it ever will.

On the other hand, you can use the spark-ml version of LogisticRegression with the binomial family (currently the only one supported). In that case, you can obtain the loss history as follows:

scala> import org.apache.spark.ml.classification.LogisticRegression
scala> val training = spark.read.format("libsvm").load("./data/mllib/sample_libsvm_data.txt")
// training: org.apache.spark.sql.DataFrame = [label: double, features: vector]

scala> val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
// lr: org.apache.spark.ml.classification.LogisticRegression = logreg_ea4e7cd94045

scala> val lrModel = lr.fit(training)
// 17/04/10 11:51:19 WARN LogisticRegression: LogisticRegression training finished but the result is not converged because: max iterations reached
// lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_ea4e7cd94045

scala> val trainingSummary = lrModel.summary
// trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@2878abcd

scala> val objectiveHistory = trainingSummary.objectiveHistory
// objectiveHistory: Array[Double] = Array(0.6833149135741672, 0.6662875751473734, 0.6217068546034619, 0.6127265245887887, 0.6060347986802872, 0.6031750687571562, 0.5969621534836274, 0.5940743031983119, 0.5906089243339021, 0.589472457649104, 0.5882187775729588)

scala> objectiveHistory.foreach(loss => println(loss))
// 0.6833149135741672
// 0.6662875751473734
// 0.6217068546034619
// 0.6127265245887887
// 0.6060347986802872
// 0.6031750687571562
// 0.5969621534836274
// 0.5940743031983119
// 0.5906089243339021
// 0.589472457649104
// 0.5882187775729588
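To visualize the training process, each loss value can be paired with its iteration index, e.g. to feed a plotting library or export as CSV. Below is a minimal, Spark-free sketch using the `objectiveHistory` values printed above as a plain array:

```scala
// Loss values copied from the objectiveHistory output above
val objectiveHistory = Array(
  0.6833149135741672, 0.6662875751473734, 0.6217068546034619,
  0.6127265245887887, 0.6060347986802872, 0.6031750687571562,
  0.5969621534836274, 0.5940743031983119, 0.5906089243339021,
  0.589472457649104, 0.5882187775729588)

// Pair each loss with its iteration number: (iteration, loss)
val history = objectiveHistory.zipWithIndex.map(_.swap)

// Print in CSV form, ready for any external plotting tool
println("iteration,loss")
history.foreach { case (iter, loss) => println(s"$iter,$loss") }
```

Note that this gives the objective (loss) per iteration, not the raw gradient; the per-iteration gradient itself is not exposed by the summary API.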

I hope this helps.

PS: This solution also works with Spark 1.6.