So, as something of a proof of concept, I tried to generate a DataFrame from the sample data produced by LinearDataGenerator.generateLinearRDD and then run logistic regression on it. On the assumption that generateLinearRDD generates data suited to linear regression, I chained it with a Binarizer in a pipeline to create a thresholded label column appropriate for logistic regression.
My code is as follows:
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
// databricks users can comment out lines between here...
val spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.master", "local")
  .getOrCreate()
import spark.implicits._
// ...and here
val data = {
  val tmp = LinearDataGenerator.generateLinearRDD(spark.sparkContext, 10000, 4, 0.05).toDF()
  MLUtils.convertVectorColumnsToML(tmp, "features").withColumnRenamed("label", "continuousLabel")
}
val binarizer = new Binarizer()
  .setInputCol("continuousLabel")
  .setOutputCol("label")
  .setThreshold(0)
val logisticRegression = new LogisticRegression()
val pipeline = new Pipeline()
  .setStages(Array(binarizer, logisticRegression))
val pipelineModel = pipeline.fit(data)
println(pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel].binarySummary.accuracy)
The stack trace from the exception is as follows:
Exception in thread "main" java.lang.IllegalArgumentException
at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
at org.apache.spark.util.ClosureCleaner$.getClassReader(ClosureCleaner.scala:46)
at org.apache.spark.util.FieldAccessFinder$$anon$3$$anonfun$visitMethodInsn$2.apply(ClosureCleaner.scala:449)
at org.apache.spark.util.FieldAccessFinder$$anon$3$$anonfun$visitMethodInsn$2.apply(ClosureCleaner.scala:432)
at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:733)
at scala.collection.mutable.HashMap$$anon$1$$anonfun$foreach$2.apply(HashMap.scala:134)
at scala.collection.mutable.HashMap$$anon$1$$anonfun$foreach$2.apply(HashMap.scala:134)
at scala.collection.mutable.HashTable$class.foreachEntry(HashTable.scala:236)
at scala.collection.mutable.HashMap.foreachEntry(HashMap.scala:40)
at scala.collection.mutable.HashMap$$anon$1.foreach(HashMap.scala:134)
at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:732)
at org.apache.spark.util.FieldAccessFinder$$anon$3.visitMethodInsn(ClosureCleaner.scala:432)
at org.apache.xbean.asm5.ClassReader.a(Unknown Source)
at org.apache.xbean.asm5.ClassReader.b(Unknown Source)
at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean$14.apply(ClosureCleaner.scala:262)
at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean$14.apply(ClosureCleaner.scala:261)
at scala.collection.immutable.List.foreach(List.scala:392)
at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:261)
at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:159)
at org.apache.spark.SparkContext.clean(SparkContext.scala:2299)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2073)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:939)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
at org.apache.spark.rdd.RDD.collect(RDD.scala:938)
at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap$1.apply(PairRDDFunctions.scala:743)
at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap$1.apply(PairRDDFunctions.scala:742)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:742)
at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass$lzycompute(MulticlassMetrics.scala:48)
at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass(MulticlassMetrics.scala:44)
at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy$lzycompute(MulticlassMetrics.scala:168)
at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy(MulticlassMetrics.scala:168)
at org.apache.spark.ml.classification.LogisticRegressionSummary$class.accuracy(LogisticRegression.scala:1445)
at org.apache.spark.ml.classification.LogisticRegressionSummaryImpl.accuracy(LogisticRegression.scala:1641)
at crossvalidation_graphs$.delayedEndpoint$crossvalidation_graphs$1(crossvalidation_graphs.scala:35)
at crossvalidation_graphs$delayedInit$body.apply(crossvalidation_graphs.scala:9)
at scala.Function0$class.apply$mcV$sp(Function0.scala:34)
at scala.runtime.AbstractFunction0.apply$mcV$sp(AbstractFunction0.scala:12)
at scala.App$$anonfun$main$1.apply(App.scala:76)
at scala.App$$anonfun$main$1.apply(App.scala:76)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.generic.TraversableForwarder$class.foreach(TraversableForwarder.scala:35)
at scala.App$class.main(App.scala:76)
at crossvalidation_graphs$.main(crossvalidation_graphs.scala:9)
at crossvalidation_graphs.main(crossvalidation_graphs.scala)
My schema currently looks like this:
root
|-- continuousLabel: double (nullable = false)
|-- features: vector (nullable = true)
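(For reference, that schema output comes from the standard DataFrame call — a minimal sketch, assuming the data value defined in the code above:)
data.printSchema()   // prints the root / continuousLabel / features tree shown above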
I'm running Spark 2.3.1 with Scala 2.11.12.
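In case the runtime environment matters here, this is a small sketch (my own addition, not part of the failing program) for printing the versions the driver actually runs on, which may differ from the build-level versions stated above:
println("Spark: " + spark.version)                          // 2.3.1 in my case
println("Scala: " + scala.util.Properties.versionString)    // version 2.11.12 in my case
println("Java:  " + System.getProperty("java.version"))     // whatever JVM the driver is actually on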