在Spark中将数据集转换为LabeledPoints

时间:2018-08-30 14:06:59

标签: scala apache-spark

我有Iris数据集.csv文件,并使用

将其加载到Scala中。
val data = spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load(file)
  .cache()

然后,我想获取使用DecisionTree模型所需的LAbeledPoints数据集,每个LabeledPoints是一个元组(标签,要素) 这是我的做法

//Group the features in and array
val groupedData = data.select(array($"petal_length", $"petal_width",
                  $"sepal_length", $"sepal_width") as "features",
                  $"species" as "label")
                  //Make the labels into doubles
                 .withColumn("label", when($"label".equalTo("versicolor"), 1.0)
                 .otherwise(when($"label".equalTo("virginica"), 2.0)
                 .otherwise(3.0)))
                 // Map each row to a LebeledPoint
                 .map(r => {new LabeledPoint(r.getAs[Double]("label"),
                                       r.getAs[Vector]("features"))})

但是当我去查看Schema和我得到的前几条记录时

模式:

groupedData.printSchema()

    root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)

前10条记录

    groupedData.show(10)

    18/08/30 10:54:23 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 3)
java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to org.apache.spark.ml.linalg.Vector
    at spark.SparkApplication$$anonfun$1.apply(SparkApplication.scala:40)
    at spark.SparkApplication$$anonfun$1.apply(SparkApplication.scala:40)
    at ...

所以我的问题是:我做错了吗?这是正确的方法吗?对我来说这是一个学习练习,我对Scala和Spark都是陌生的。

PS:不是讲英语的人,如果不清楚,请问

0 个答案:

没有答案