Basically I have already cleaned my dataset, removing the header, bad values, and so on. I am now trying to train a random forest classifier so it can make predictions. This is what I have so far:
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest

object Churn {
  def main(args: Array[String]): Unit = {
    // set up the Spark context
    val conf = new SparkConf().setAppName("Churn")
    val sc = new SparkContext(conf)
    // load the CSV and map each row to a LabeledPoint
    val csv = sc.textFile("file://filename.csv")
    val data = csv.map { line =>
      val parts = line.split(",").map(_.trim)
      val stringvec = Array(parts(1)) ++ parts.slice(4, 20)
      val label = parts(20).toDouble
      val vec = stringvec.map(_.toDouble)
      LabeledPoint(label, Vectors.dense(vec))
    }
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (training, testing) = (splits(0), splits(1))
    val model = RandomForest.trainClassifier(training)
  }
}
But I am getting an error like this:
error: overloaded method value trainClassifier with alternatives:
(input: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint],strategy: org.apache.spark.mllib.tree.configuration.Strategy,numTrees: Int,featureSubsetStrategy: String,seed: Int)org.apache.spark.mllib.tree.model.RandomForestModel
cannot be applied to (org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint])
val model = RandomForest.trainClassifier(training)
Googling it has gotten me nowhere. I would appreciate it if you could explain what this error is and why I am getting it; then I can research the solution myself.
Answer 0 (score: 1)
You are not passing enough arguments to RandomForest.trainClassifier(): there is no method trainClassifier(RDD[LabeledPoint]). There are several overloaded versions, but you can find the simplest one in the trainClassifier documentation. Besides the labeled points you also have to pass a Strategy, the number of trees, a featureSubsetStrategy, and a seed (Int).
An example would look like this:
RandomForest.trainClassifier(training,
  Strategy.defaultStrategy("Classification"),
  3,
  "auto",
  12345)
In practice you would use many more trees than 3, and a different seed.
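If you don't want to build a Strategy yourself, there is also an overload of trainClassifier that takes the individual parameters directly. A minimal sketch, assuming a binary churn label; the values of numTrees, maxDepth, etc. here are illustrative, not tuned for your data:

import org.apache.spark.mllib.tree.RandomForest

val model = RandomForest.trainClassifier(
  training,          // RDD[LabeledPoint]
  2,                 // numClasses: binary churn label (assumption)
  Map[Int, Int](),   // categoricalFeaturesInfo: empty map = all features continuous
  100,               // numTrees
  "auto",            // featureSubsetStrategy
  "gini",            // impurity
  5,                 // maxDepth
  32,                // maxBins
  12345)             // seed

You can then check the model on the held-out split by comparing model.predict(point.features) against point.label for each point in testing.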
Answer 1 (score: 0)
Full answer in GitHub. Original Dataset.
Do these steps one by one. For the test data, use a two-row .csv file: the first row is read as the header, and the second row is the test record.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{LabeledPoint, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}

object RandomForest {
  def main(args: Array[String]): Unit = {
    val sparkSess = org.apache.spark.sql.SparkSession.builder()
      .master("local[*]").appName("car_mpg").getOrCreate()
    import sparkSess.implicits._

    // Read the training CSV and map each row to a LabeledPoint
    // (label = mpg, features = cylinders .. acceleration).
    val carData = sparkSess.read.format("csv")
      .option("header", "true").option("inferSchema", "true")
      .csv("D:\\testData\\mpg.csv")
      .toDF("mpg", "cylinders", "displacement", "hp", "weight", "acceleration", "model_year", "origin", "car_name")
      .map(data => LabeledPoint(data(0).toString.toDouble, Vectors.dense(Array(data(1).toString.toDouble,
        data(2).toString.toDouble, data(3).toString.toDouble, data(4).toString.toDouble, data(5).toString.toDouble))))
    val carData_df = carData.toDF("label", "features")

    // Index the feature vector so the tree learner can recognize categorical features.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features").setOutputCol("indexedFeatures").fit(carData_df)

    // randomSplit normalizes its weights, so a single weight of 0.7 keeps
    // all rows in one split; the test set comes from a separate file below.
    val Array(training) = carData_df.randomSplit(Array(0.7))

    // Use the indexed features produced by the VectorIndexer stage.
    val randomReg = new RandomForestRegressor()
      .setLabelCol("label").setFeaturesCol("indexedFeatures")
    val model = new Pipeline()
      .setStages(Array(featureIndexer, randomReg)).fit(training)

    // Read the two-row test CSV (header + one data row) the same way.
    val testData = sparkSess.read.format("csv")
      .option("header", "true").option("inferSchema", "true")
      .csv("D:\\testData\\testData.csv")
      .toDF("mpg", "cylinders", "displacement", "hp", "weight", "acceleration", "model_year", "origin", "car_name")
      .map(data => LabeledPoint(data(0).toString.toDouble,
        Vectors.dense(data(1).toString.toDouble, data(2).toString.toDouble,
          data(3).toString.toDouble, data(4).toString.toDouble, data(5).toString.toDouble)))

    val predictions = model.transform(testData)
    predictions.select("prediction", "label", "features").show()

    // Report root-mean-squared error on the test data.
    val rmse = new RegressionEvaluator().setLabelCol("label")
      .setPredictionCol("prediction").setMetricName("rmse").evaluate(predictions)
    println("Root Mean Squared Error:\n" + rmse)

    // Print the learned ensemble of regression trees.
    val treeModels = model.stages(1).asInstanceOf[RandomForestRegressionModel]
    println("Learned regression tree models:\n" + treeModels.toDebugString)

    sparkSess.stop()
  }
}
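Note that this answer trains a regressor, while the original question is about classification. A minimal sketch of the equivalent ml-pipeline classifier; trainDF and testDF are hypothetical DataFrames assumed to have "label" (0.0/1.0) and "features" columns, built the same way as carData_df above, and the parameter values are illustrative:

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// trainDF / testDF: assumed DataFrames with "label" and "features" columns.
val rfc = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(100)
  .setSeed(12345)

val rfcModel = rfc.fit(trainDF)
val preds = rfcModel.transform(testDF)

// Accuracy on the held-out test data.
val accuracy = new MulticlassClassificationEvaluator()
  .setLabelCol("label").setPredictionCol("prediction")
  .setMetricName("accuracy").evaluate(preds)
println(s"Test accuracy: $accuracy")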