Training a random forest classifier in Spark

Asked: 2016-01-06 07:41:25

Tags: scala apache-spark

Basically, I have cleaned up my dataset, removing the header, bad values and so on. I am now trying to train a random forest classifier so that it can make predictions. This is what I have so far:

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest

object Churn {
   def main(args: Array[String]): Unit = {
    //setting spark context
    val conf = new SparkConf().setAppName("Churn")
    val sc = new SparkContext(conf)
    //loading and mapping data into RDD
    val csv = sc.textFile("file://filename.csv")
    val data = csv.map { line =>
      val parts = line.split(",").map(_.trim)
      val stringvec = Array(parts(1)) ++ parts.slice(4, 20)
      val label = parts(20).toDouble
      val vec = stringvec.map(_.toDouble)
      LabeledPoint(label, Vectors.dense(vec))
    }
    val splits = data.randomSplit(Array(0.7,0.3))
    val (training, testing) = (splits(0), splits(1))
    val model = RandomForest.trainClassifier(training)
    }
}

But I am getting an error like the following:

error: overloaded method value trainClassifier with alternatives:

  (input: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint],strategy: org.apache.spark.mllib.tree.configuration.Strategy,numTrees: Int,featureSubsetStrategy: String,seed: Int)org.apache.spark.mllib.tree.model.RandomForestModel
 cannot be applied to (org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint])
   val model = RandomForest.trainClassifier(training)

Googling has gotten me nowhere. I would appreciate it if you could explain what this error means and why I am getting it; then I can research a solution myself.

2 Answers:

Answer 0 (score: 1):

You are not passing enough arguments to RandomForest.trainClassifier(); there is no trainClassifier(RDD[LabeledPoint]) method. There are several overloaded versions, but you can find the simple one in the trainClassifier API documentation.

You need to send not just the labeled points but also a Strategy, the number of trees, a featureSubsetStrategy, and a seed (Int).

An example would look like this:

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy

RandomForest.trainClassifier(training,
  Strategy.defaultStrategy("Classification"),
  3,
  "auto",
  12345)

In practice, you would use far more than 3 trees and a different seed.
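
For completeness, a minimal sketch of training and then evaluating on the held-out split, reusing the imports above (training/testing are the RDDs from the question; 100 trees and the seed are arbitrary choices):

val model = RandomForest.trainClassifier(training,
  Strategy.defaultStrategy("Classification"), 100, "auto", 12345)

// Score each held-out point and compute the misclassification rate
val labelsAndPreds = testing.map { point =>
  (point.label, model.predict(point.features))
}
val testErr = labelsAndPreds.filter { case (l, p) => l != p }.count.toDouble / testing.count()
println(s"Test error = $testErr")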

Answer 1 (score: 0):

Full answer in Github; Original Dataset.

Do these things one by one. For the test data, use a two-line .csv file: the first line is taken as the header and the second line as the test data.
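
For illustration, such a testData.csv could look like the following; the header matches the column names used in the code below, and the data line is a sample record from the classic auto-mpg dataset (shown only as an example of the format):

mpg,cylinders,displacement,hp,weight,acceleration,model_year,origin,car_name
18.0,8,307.0,130.0,3504.0,12.0,70,1,chevrolet chevelle malibu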

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{LabeledPoint, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}


object RandomForest {
   def main(args: Array[String]): Unit = {
     val sparkSess = org.apache.spark.sql.SparkSession.builder().master("local[*]").appName("car_mpg").getOrCreate()
     import sparkSess.implicits._
     // Load the training CSV; "inferSchema" (not "InterScema") tells Spark to detect column types
     val carData = sparkSess.read.format("csv").option("header", "true").option("inferSchema", "true")
       .csv("D:\\testData\\mpg.csv").toDF("mpg","cylinders","displacement","hp","weight","acceleration","model_year","origin","car_name")
       .map(data => LabeledPoint(data(0).toString.toDouble, Vectors.dense(Array(data(1).toString.toDouble,
         data(2).toString.toDouble, data(3).toString.toDouble, data(4).toString.toDouble, data(5).toString.toDouble))))

     val carData_df = carData.toDF("label", "features")

     // Index the feature vector so the pipeline can recognize categorical features
     val featureIndexer = new VectorIndexer()
       .setInputCol("features").setOutputCol("indexedFeatures").fit(carData)

     val Array(training) = carData_df.randomSplit(Array(0.7))

     // Point the regressor at the indexer's output so the indexing stage is actually used
     val randomReg = new RandomForestRegressor()
       .setLabelCol("label").setFeaturesCol("indexedFeatures")

     val model = new Pipeline()
       .setStages(Array(featureIndexer, randomReg)).fit(training)

     val testData = sparkSess.read.format("csv").option("header", "true").option("inferSchema", "true")
       .csv("D:\\testData\\testData.csv")
       .toDF("mpg","cylinders","displacement","hp","weight","acceleration","model_year","origin","car_name")
       .map(data => LabeledPoint(data(0).toString.toDouble,
         Vectors.dense(data(1).toString.toDouble, data(2).toString.toDouble,
           data(3).toString.toDouble, data(4).toString.toDouble, data(5).toString.toDouble)))

     val predictions = model.transform(testData)
     predictions.select("prediction", "label", "features").show()

     val rmse = new RegressionEvaluator().setLabelCol("label")
       .setPredictionCol("prediction").setMetricName("rmse").evaluate(predictions)
     println("Root Mean Squared Error :\n" + rmse)

     val treeModels = model.stages(1).asInstanceOf[RandomForestRegressionModel]
     println("Learned Regression tree models :\n" + treeModels.toDebugString)

     sparkSess.stop()
   }
}
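
If you want to reuse the fitted pipeline later, it can be saved and reloaded with the standard spark.ml persistence API; a minimal sketch (the path here is a placeholder):

// Persist the fitted pipeline to disk and load it back
model.write.overwrite().save("D:\\testData\\rf_pipeline")
val sameModel = org.apache.spark.ml.PipelineModel.load("D:\\testData\\rf_pipeline")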
