Logistic regression on a dataset from generateLinearRDD fails with java.lang.IllegalArgumentException

Date: 2018-07-17 18:04:31

Tags: scala apache-spark apache-spark-mllib

As a proof of concept, I'm trying to generate a DataFrame from the sample data produced by LinearDataGenerator.generateLinearRDD and then run a logistic regression on it.

On the assumption that generateLinearRDD produces data suitable for linear regression, I chain it with a Binarizer in a Pipeline to create a thresholded label column suitable for logistic regression.

My code is as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// databricks users can comment out lines between here...
  val spark = SparkSession
    .builder()
    .appName("Java Spark SQL basic example")
    .config("spark.master", "local")
    .getOrCreate()

  import spark.implicits._
// ...and here

  val data = {
    val tmp = LinearDataGenerator.generateLinearRDD(spark.sparkContext, 10000, 4, 0.05).toDF()
    MLUtils.convertVectorColumnsToML(tmp, "features").withColumnRenamed("label", "continuousLabel")
  }

  val binarizer = new Binarizer()
    .setInputCol("continuousLabel")
    .setOutputCol("label")
    .setThreshold(0)

  val logisticRegression = new LogisticRegression()

  val pipeline = new Pipeline()
    .setStages(Array(binarizer, logisticRegression))

  val pipelineModel = pipeline.fit(data)

  println(pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel].binarySummary.accuracy)

The stack trace from the exception is as follows:

Exception in thread "main" java.lang.IllegalArgumentException
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.spark.util.ClosureCleaner$.getClassReader(ClosureCleaner.scala:46)
    at org.apache.spark.util.FieldAccessFinder$$anon$3$$anonfun$visitMethodInsn$2.apply(ClosureCleaner.scala:449)
    at org.apache.spark.util.FieldAccessFinder$$anon$3$$anonfun$visitMethodInsn$2.apply(ClosureCleaner.scala:432)
    at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:733)
    at scala.collection.mutable.HashMap$$anon$1$$anonfun$foreach$2.apply(HashMap.scala:134)
    at scala.collection.mutable.HashMap$$anon$1$$anonfun$foreach$2.apply(HashMap.scala:134)
    at scala.collection.mutable.HashTable$class.foreachEntry(HashTable.scala:236)
    at scala.collection.mutable.HashMap.foreachEntry(HashMap.scala:40)
    at scala.collection.mutable.HashMap$$anon$1.foreach(HashMap.scala:134)
    at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:732)
    at org.apache.spark.util.FieldAccessFinder$$anon$3.visitMethodInsn(ClosureCleaner.scala:432)
    at org.apache.xbean.asm5.ClassReader.a(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.b(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
    at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean$14.apply(ClosureCleaner.scala:262)
    at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean$14.apply(ClosureCleaner.scala:261)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:261)
    at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:159)
    at org.apache.spark.SparkContext.clean(SparkContext.scala:2299)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2073)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:939)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:938)
    at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap$1.apply(PairRDDFunctions.scala:743)
    at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap$1.apply(PairRDDFunctions.scala:742)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:742)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass$lzycompute(MulticlassMetrics.scala:48)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass(MulticlassMetrics.scala:44)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy$lzycompute(MulticlassMetrics.scala:168)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy(MulticlassMetrics.scala:168)
    at org.apache.spark.ml.classification.LogisticRegressionSummary$class.accuracy(LogisticRegression.scala:1445)
    at org.apache.spark.ml.classification.LogisticRegressionSummaryImpl.accuracy(LogisticRegression.scala:1641)
    at crossvalidation_graphs$.delayedEndpoint$crossvalidation_graphs$1(crossvalidation_graphs.scala:35)
    at crossvalidation_graphs$delayedInit$body.apply(crossvalidation_graphs.scala:9)
    at scala.Function0$class.apply$mcV$sp(Function0.scala:34)
    at scala.runtime.AbstractFunction0.apply$mcV$sp(AbstractFunction0.scala:12)
    at scala.App$$anonfun$main$1.apply(App.scala:76)
    at scala.App$$anonfun$main$1.apply(App.scala:76)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.generic.TraversableForwarder$class.foreach(TraversableForwarder.scala:35)
    at scala.App$class.main(App.scala:76)
    at crossvalidation_graphs$.main(crossvalidation_graphs.scala:9)
    at crossvalidation_graphs.main(crossvalidation_graphs.scala)

My schema (as printed by data.printSchema()) currently looks like this:

root
 |-- continuousLabel: double (nullable = false)
 |-- features: vector (nullable = true)

I'm running Spark 2.3.1 on Scala 2.11.12.
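
For completeness, a minimal build.sbt along these lines would reproduce that setup (a sketch assuming sbt; the build file isn't part of the question):

// Hypothetical minimal build.sbt for the versions above (not from the question)
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"   % "2.3.1",
  "org.apache.spark" %% "spark-mllib" % "2.3.1"
)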

1 Answer:

Answer 0 (score: 0)

Similar to this guy, my actual problem was that I was running on Java 10 rather than Java 8. When I switched back to Java 8, my code worked.
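
As far as I can tell, the underlying cause is that Spark 2.3.x's ClosureCleaner reads closure bytecode with a repackaged ASM 5 (the org.apache.xbean.asm5 frames in the stack trace), and ASM 5 cannot parse class files emitted by Java 9 or later, so the ClassReader constructor throws a bare IllegalArgumentException. A quick runtime guard like the following (a sketch of my own, not part of the original answer) catches the mismatch before Spark does:

// Sketch (my own check, not Spark API): fail fast unless we're on Java 8.
// Java 8 reports a version string starting with "1.8"; Java 9+ reports "9", "10", ...
val javaVersion = System.getProperty("java.version")
require(javaVersion.startsWith("1.8"),
  s"Spark 2.3.x requires Java 8, but this JVM is Java $javaVersion")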