A simple example

Date: 2015-11-26 15:47:06

Tags: scala apache-spark dataframe apache-spark-sql apache-spark-ml

I am trying to run an experiment with RandomForestClassifier from the spark.ml package (version 1.5.2). The dataset I use comes from the LogisticRegression example in the Spark ML guide.

Here is the code:

import org.apache.spark.ml.classification.{LogisticRegression, RandomForestClassifier}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row

// Prepare training data from a list of (label, features) tuples.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val rf = new RandomForestClassifier()

val model = rf.fit(training)

And here is the error I get:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.
    at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:87)
    at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:42)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:48)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:53)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:55)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:57)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:59)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:61)
    at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:63)
    at $iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:65)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:67)
    at $iwC$$iwC$$iwC.<init>(<console>:69)
    at $iwC$$iwC.<init>(<console>:71)
    at $iwC.<init>(<console>:73)
    at <init>(<console>:75)
    at .<init>(<console>:79)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at org.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:1065)
    at org.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1340)
    at org.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:840)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:871)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:819)
    at org.apache.spark.repl.SparkILoop.reallyInterpret$1(SparkILoop.scala:857)
    at org.apache.spark.repl.SparkILoop.interpretStartingWith(SparkILoop.scala:902)
    at org.apache.spark.repl.SparkILoop.command(SparkILoop.scala:814)
    at org.apache.spark.repl.SparkILoop.processLine$1(SparkILoop.scala:657)
    at org.apache.spark.repl.SparkILoop.innerLoop$1(SparkILoop.scala:665)
    at org.apache.spark.repl.SparkILoop.org$apache$spark$repl$SparkILoop$$loop(SparkILoop.scala:670)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply$mcZ$sp(SparkILoop.scala:997)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply(SparkILoop.scala:945)
    at org.apache.spark.repl.SparkILoop$$anonfun$org$apache$spark$repl$SparkILoop$$process$1.apply(SparkILoop.scala:945)
    at scala.tools.nsc.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:135)
    at org.apache.spark.repl.SparkILoop.org$apache$spark$repl$SparkILoop$$process(SparkILoop.scala:945)
    at org.apache.spark.repl.SparkILoop.process(SparkILoop.scala:1059)
    at org.apache.spark.repl.Main$.main(Main.scala:31)
    at org.apache.spark.repl.Main.main(Main.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:674)
    at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180)
    at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205)
    at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:120)
    at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

The problem arises when the function tries to compute the number of classes in the column "label".

As you can see at line 84 of the source code of RandomForestClassifier, the function calls DataFrame.schema with the parameter "label". This call works fine and returns an org.apache.spark.sql.types.StructField object. Then, the function org.apache.spark.ml.util.MetadataUtils.getNumClasses is called. Since it does not return the expected output, an exception is raised at line 87.

After a quick look at the getNumClasses source code, I suppose that the error is due to the fact that the data in the column "label" is neither a BinaryAttribute nor a NominalAttribute. However, I do not know how to fix this problem.
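To make the failure mode concrete, here is a simplified pure-Scala model of the decision getNumClasses makes (my reading of the source, not the actual Spark code): it can only report a class count when the label column's attribute metadata marks the column as binary or nominal.

```scala
// Simplified model of MetadataUtils.getNumClasses (illustrative, not Spark's code).
// The real version pattern-matches on the ML Attribute stored in the column metadata.
sealed trait LabelAttr
case object NumericAttr extends LabelAttr                    // plain doubles: no class info
case object BinaryAttr extends LabelAttr                     // exactly two classes
case class NominalAttr(values: Option[Seq[String]]) extends LabelAttr

def getNumClasses(attr: LabelAttr): Option[Int] = attr match {
  case BinaryAttr          => Some(2)
  case NominalAttr(values) => values.map(_.size)
  case _                   => None  // caller then throws IllegalArgumentException
}
```

A DataFrame built from a plain Seq of (Double, Vector) tuples carries no such metadata on "label", so the match falls through to None and RandomForestClassifier throws.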

My question:

How can I fix this problem?

Thanks a lot for reading my question, and for your help!

1 Answer:

Answer 0 (score: 8)

Let's first fix the imports to remove the ambiguity:

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.mllib.linalg.Vectors

I will use the same data you used:

val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

Then create the pipeline stages:

val stages = new scala.collection.mutable.ArrayBuffer[PipelineStage]()

First, for the classification, re-index the classes:

val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(training)
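To build intuition for what the StringIndexer step produces, here is a minimal pure-Scala sketch of its core fit logic (an approximation, not Spark's implementation): indices are assigned by descending label frequency, so the most frequent label gets index 0.

```scala
// Minimal sketch of StringIndexer's fit logic (approximation, not Spark's code):
// index 0 goes to the most frequent label, 1 to the next, and so on.
// The tie-breaking order (alphabetical here) is an assumption; Spark does not specify it.
def fitStringIndexer(labels: Seq[String]): Map[String, Double] =
  labels.groupBy(identity)
    .toSeq
    .sortBy { case (label, occurrences) => (-occurrences.size, label) }
    .zipWithIndex
    .map { case ((label, _), index) => label -> index.toDouble }
    .toMap
```

On the four training labels (two 1.0s and two 0.0s) both classes are equally frequent, so either may get index 0; what matters for RandomForestClassifier is that the output column now carries nominal metadata with a known number of values.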
Next, identify the categorical features using VectorIndexer:

val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(10).fit(training)
stages += featuresIndexer

val tmp = featuresIndexer.transform(labelIndexer.transform(training))
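As a rough mental model of what VectorIndexer does here (a sketch under my assumptions, not Spark's implementation): any feature with at most maxCategories distinct values is treated as categorical, and its values are mapped to indices 0.0, 1.0, ... in ascending numeric order; features with more distinct values are left as continuous.

```scala
// Sketch of VectorIndexer's per-feature behavior (illustrative; assumes the values
// of a categorical feature are indexed in ascending numeric order).
def indexFeature(column: Seq[Double], maxCategories: Int): Seq[Double] = {
  val distinct = column.distinct
  if (distinct.size <= maxCategories) {
    // categorical: map each observed value to its rank among the sorted distinct values
    val mapping = distinct.sorted.zipWithIndex.map { case (v, i) => v -> i.toDouble }.toMap
    column.map(mapping)
  } else column  // too many distinct values: left as a continuous feature
}
```

Applied to the first feature of the training set (values 0.0, 2.0, 2.0, 0.0), this yields 0.0, 1.0, 1.0, 0.0, which is consistent with the indexedFeatures column in the output further down.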
      
Then learn the random forest:

val rf = new RandomForestClassifier().setFeaturesCol(featuresIndexer.getOutputCol).setLabelCol(labelIndexer.getOutputCol)

stages += rf
val pipeline = new Pipeline().setStages(stages.toArray)

// Fit the Pipeline
val pipelineModel = pipeline.fit(tmp)

val results = pipelineModel.transform(training)

results.show
        
//+-----+--------------+---------------+-------------+-----------+----------+
//|label|      features|indexedFeatures|rawPrediction|probability|prediction|
//+-----+--------------+---------------+-------------+-----------+----------+
//|  1.0| [0.0,1.1,0.1]|  [0.0,1.0,2.0]|   [1.0,19.0]|[0.05,0.95]|       1.0|
//|  0.0|[2.0,1.0,-1.0]|  [1.0,0.0,0.0]|   [17.0,3.0]|[0.85,0.15]|       0.0|
//|  0.0| [2.0,1.3,1.0]|  [1.0,3.0,3.0]|   [14.0,6.0]|  [0.7,0.3]|       0.0|
//|  1.0|[0.0,1.2,-0.5]|  [0.0,2.0,1.0]|   [1.0,19.0]|[0.05,0.95]|       1.0|
//+-----+--------------+---------------+-------------+-----------+----------+
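Reading the output: for a random forest, rawPrediction holds the per-class vote counts across the trees (20 trees by default), and probability is, as far as I understand it, simply those counts normalized to sum to 1:

```scala
// Vote counts -> class probabilities, as reflected in the table above
// (a sketch of the relationship, not Spark's internal code).
def toProbability(rawPrediction: Seq[Double]): Seq[Double] = {
  val total = rawPrediction.sum
  rawPrediction.map(_ / total)
}
```

For example, the first row's raw prediction [1.0, 19.0] becomes [0.05, 0.95], matching the table.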
        

For those who want more details about the feature transformers (StringIndexer and VectorIndexer), I suggest reading the official documentation here.