Type mismatch error when running ml.PredictionModel in Spark

Date: 2018-04-01 09:52:23

Tags: apache-spark apache-spark-sql apache-spark-ml

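After training all of my models, I am trying to rename each model's prediction column so that every model's predictions can be uniquely identified within the dataset. I get a type mismatch error, as shown below: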

import org.apache.spark.ml.PredictionModel
import org.apache.spark.sql.DataFrame

val models = Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM", gbmModel))

This outputs:

models: Seq[(String, Any)] = List((NB,NaiveBayesModel (uid=nb_699528805899) with 2 classes), (DT,()), (RF,RandomForestClassificationModel (uid=rfc_403e93000cb6) with 10 trees), (GBM,GBTClassificationModel (uid=gbtc_e778e2781d0b) with 20 trees))

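def mlData(inputData: DataFrame, responseColumn: String,
           baseModels: Seq[(String, PredictionModel[_, _])]): DataFrame = {
  // For each model: score the input, keep only the key and that model's
  // prediction column, and rename the column to "<name>_prediction".
  // The per-model frames are then joined on row_id, and the response
  // column is joined back in at the end.
  baseModels.map { case (name, model) =>
    model.transform(inputData)
      .select("row_id", model.getPredictionCol)
      .withColumnRenamed(model.getPredictionCol, s"${name}_prediction")
  }.reduceLeft((a, b) => a.join(b, Seq("row_id"), "inner"))
    .join(inputData.select("row_id", responseColumn), Seq("row_id"), "inner")
}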

This outputs:

mlData: (inputData: org.apache.spark.sql.DataFrame, responseColumn: String, baseModels: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])]) org.apache.spark.sql.DataFrame

val mlTrainData= mlData(transferData, "value", models).drop("row_id")

I get a type mismatch error, which as far as I can tell should not happen:

<console>:102: error: type mismatch;
 found   : Seq[(String, Any)]
 required: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])]
       val mlTrainData= mlData(transferData, "value", models).drop("row_id")

1 Answer:

Answer 0 (score: 1)

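From the output it is clear that the second element of the DT tuple is Unit, not a PredictionModel: it is printed as (DT,()). Unit and the model classes share no supertype other than Any, which is why the whole object is inferred as Seq[(String, Any)] and your code fails.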
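The inference behaviour can be reproduced without Spark. The following is only a minimal sketch: `Model`, `NaiveBayesLike` and `RandomForestLike` are hypothetical stand-ins for the Spark model classes, and `dt` is set to `()` on purpose to mimic the DT entry in the question.

```scala
object InferenceDemo {
  // Hypothetical stand-ins for Spark's model classes (assumptions, not Spark API).
  trait Model { def name: String }
  final case class NaiveBayesLike(name: String) extends Model
  final case class RandomForestLike(name: String) extends Model

  val nb = NaiveBayesLike("nb")
  val rf = RandomForestLike("rf")
  val dt: Unit = () // deliberately Unit, like the DT element in the question

  // Every second element is a Model, so this annotation type-checks.
  val good: Seq[(String, Model)] = Seq(("NB", nb), ("RF", rf))

  // One second element is Unit, so the element type can only be described as Any.
  val bad: Seq[(String, Any)] = Seq(("NB", nb), ("DT", dt), ("RF", rf))

  // A type pattern keeps only the entries that really hold a Model.
  val onlyModels: Seq[(String, Model)] =
    bad.collect { case (label, m: Model) => (label, m) }
}
```

Filtering with `collect` only hides the symptom; the DT model would still be missing from the result, so the actual fix is to find out why that value is Unit.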

Since you have not provided any context, it is unclear how you ended up with that value.
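One common way to end up with a Unit value, offered here purely as a guess and not as a diagnosis of the code in the question, is a block whose last expression is a statement such as an assignment. The sketch below uses hypothetical names (`FittedModel`, `TreeLike`, `fit`, `trainedCount`) and shows how an explicit type annotation makes the compiler report the problem at the definition rather than at the later call.

```scala
object UnitPitfall {
  // Hypothetical stand-ins (assumptions, not Spark API).
  trait FittedModel
  final class TreeLike extends FittedModel

  var trainedCount = 0
  def fit(): FittedModel = new TreeLike

  // The block ends with an assignment, so its value (and the val) is Unit.
  val accidental = { val m = fit(); trainedCount += 1 }

  // Ending the block with the model gives the intended value; the annotation
  // would reject a Unit result at this line.
  val intended: FittedModel = { val m = fit(); trainedCount += 1; m }

  // Annotating the sequence has the same effect for every element.
  val models: Seq[(String, FittedModel)] = Seq(("DT", intended))
  // Seq(("DT", accidental)) would not compile under this annotation.
}
```

In the question's setting, the analogous step would be to declare models as Seq[(String, PredictionModel[_, _])] so that the error points at the DT entry.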