Apache Spark random forest regressor: "model.transform" on a dataset not used in training

Time: 2018-03-23 07:30:12

Tags: scala apache-spark

I am using a random forest regressor in Spark ML. Transforming a new dataset that was not used for either training or testing raises an error.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.regression.RandomForestRegressor

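// Read the raw churn CSV and cast each column to its expected type.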
val data_N = spark.read.option("header", "true").csv("C:/Users/admin/Desktop/ChurnData12Mar18.csv")
val data = data_N.selectExpr(
  "CAST (Acc_Name AS string)", "CAST (Account_ID AS string)", "CAST (Subsc_ID AS string)",
  "CAST (Prod_ID AS int)", "CAST (Prod_Name AS string)", "CAST (Location_State AS string)",
  "CAST (Year AS int)", "CAST (Location_Country AS string)", "CAST (Month AS string)",
  "CAST (Active_Subsc AS int)", "CAST (Payment_Stat AS string)", "CAST (Payment_Method AS string)",
  "CAST (Demographics AS string)", "CAST (No_Renewals AS int)", "CAST (No_Churn AS int)",
  "CAST (No_Acquisitions AS int)", "CAST (Cust_Satisfaction AS int)", "CAST (Subsc_Duration AS int)",
  "CAST (Payment_Pending_Since AS int)", "CAST (Invoice_Raised_Since AS int)", "CAST (Expiry_Date AS int)",
  "CAST (Amount AS int)")
val stringColumns = Array("Account_ID","Month","Payment_Stat","Payment_Method")

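// Index each string column into a numeric <col>_index column; rows with labels unseen at fit time are skipped.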
val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringColumns.map( cname => new StringIndexer().setInputCol(cname).setHandleInvalid("skip").setOutputCol(s"${cname}_index"))
val index_pipeline = new Pipeline().setStages(index_transformers)
val index_model = index_pipeline.fit(data)
val df_indexed = index_model.transform(data)

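// One-hot encode every *_index column into a *_index_vec vector column.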
val indexColumns  = df_indexed.columns.filter(x => x contains "index")
val one_hot_encoders: Array[org.apache.spark.ml.PipelineStage] = indexColumns.map(cname => new OneHotEncoder().setInputCol(cname).setOutputCol(s"${cname}_vec"))

val one_hot_pipeline = new Pipeline().setStages(one_hot_encoders)
val df_encoded = one_hot_pipeline.fit(df_indexed).transform(df_indexed)

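// Assemble the encoded and numeric columns into a single feature vector, then let VectorIndexer
// treat any feature with at most 4 distinct values as categorical.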
val numFeatNames = Seq("Month_index_vec","Payment_Stat_index_vec","Payment_Method_index_vec","Prod_ID","Year","Cust_Satisfaction","Subsc_Duration","Payment_Pending_Since","Invoice_Raised_Since","Expiry_Date")
val allFeatNames = numFeatNames 
val assembler = new VectorAssembler().setInputCols(Array(allFeatNames: _*))
                                     .setOutputCol("features")
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)

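// L1-normalize the indexed feature vector.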
import org.apache.spark.ml.feature.Normalizer
val normalizer = new Normalizer()
  .setInputCol("indexedFeatures")
  .setOutputCol("normFeatures")
  .setP(1.0)

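// Random forest regressor that predicts No_Churn from the normalized features.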
val rf = new RandomForestRegressor()
  .setLabelCol("No_Churn")
  .setFeaturesCol("normFeatures")

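// Full pipeline: string indexing -> one-hot encoding -> assembling -> vector indexing -> normalization -> random forest.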
val pipeline = new Pipeline()
    .setStages(Array(index_pipeline,one_hot_pipeline, assembler, featureIndexer, normalizer, rf))

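// Split the pre-2018 data 60/40 into training and test sets and fit the pipeline on the training portion.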
val Array(trainingData, testData) = data.filter("Year < 2018").randomSplit(Array(0.6, 0.4))
val model = pipeline.fit(trainingData)

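// Score the held-out test set.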
val predictions = model.transform(testData)

predictions.select("prediction", "No_Churn", "features").show(5)

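// Evaluate the test predictions with RMSE.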
val evaluator = new RegressionEvaluator()
  .setLabelCol("No_Churn")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println("Root Mean Squared Error (RMSE) on test data = " + rmse)

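// Score the 2018 rows, which were excluded when the pipeline was fitted. The transform itself is lazy;
// the error shown below is thrown when the SQL result is materialized by show().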
val predictions = model.transform(data.filter("Year == 2018"))
predictions.count
predictions.createOrReplaceTempView("prediction")
spark.sql("select prediction, No_Churn, features from prediction order by prediction desc").show(200)

Output:

rmse: Double = 0.23059628575927363
Root Mean Squared Error (RMSE) on test data = 0.23059628575927363
predictions: org.apache.spark.sql.DataFrame = [Acc_Name: string, Account_ID: string ... 32 more fields]
res70: Long = 853

Error:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 49.0 failed 1 times, most recent failure: Lost task
 0.0 in stage 49.0 (TID 49, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined
function($anonfun$11: (vector) => vector)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
    at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
    at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
    at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$30.apply(RDD.scala:1422)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$30.apply(RDD.scala:1419)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:796)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:796)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
    at java.lang.Thread.run(Unknown Source)
   Caused by: java.util.NoSuchElementException: key not found: 2018.0
    at scala.collection.MapLike$class.default(MapLike.scala:228)
    at scala.collection.AbstractMap.default(Map.scala:59)
    at scala.collection.MapLike$class.apply(MapLike.scala:141)
    at scala.collection.AbstractMap.apply(Map.scala:59)
    at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:340)
    at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:318)
    at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:363)
    at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:363)
    ... 20 more

   Driver stacktrace:
     at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
     at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
     at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
     at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
     at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
     at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
     at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
     at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
     at scala.Option.foreach(Option.scala:257)
     at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
     at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
     at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
     at org.apache.spark.SparkContext.runJob(SparkContext.scala:1918)
     at org.apache.spark.SparkContext.runJob(SparkContext.scala:1981)
     at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1025)
     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
     at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
     at org.apache.spark.rdd.RDD.reduce(RDD.scala:1007)
     at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1.apply(RDD.scala:1428)
     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
     at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
     at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
     at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1415)
     at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:133)
     at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2371)
     at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
     at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2765)
     at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2370)
     at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2377)
     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2113)
     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2112)
     at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2795)
     at org.apache.spark.sql.Dataset.head(Dataset.scala:2112)
     at org.apache.spark.sql.Dataset.take(Dataset.scala:2327)
     at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
     at org.apache.spark.sql.Dataset.show(Dataset.scala:636)
     at org.apache.spark.sql.Dataset.show(Dataset.scala:595)
     ... 46 elided
   Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$11: (vector) => vector)
     at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
     at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
     at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
     at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
     at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
     at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
     at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
     at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$30.apply(RDD.scala:1422)
     at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$30.apply(RDD.scala:1419)
     at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:796)
     at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:796)
     at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
     at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
     at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
     at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
     at org.apache.spark.scheduler.Task.run(Task.scala:99)
     at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)
     ... 3 more
   Caused by: java.util.NoSuchElementException: key not found: 2018.0
     at scala.collection.MapLike$class.default(MapLike.scala:228)
     at scala.collection.AbstractMap.default(Map.scala:59)
     at scala.collection.MapLike$class.apply(MapLike.scala:141)
     at scala.collection.AbstractMap.apply(Map.scala:59)
     at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:340)
     at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$10.apply(VectorIndexer.scala:318)
     at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:363)
     at org.apache.spark.ml.feature.VectorIndexerModel$$anonfun$11.apply(VectorIndexer.scala:363)
     ... 20 more

0 Answers:

No answers yet