我正在尝试使用Spark的小组built a large amount of random forest models。我的方法是缓存一个大的输入数据文件,根据school_id将其拆分成片段,将单个学校输入文件缓存到内存中,在每个文件上运行模型,然后提取标签和预测。
model_input.cache()
val schools = model_input.select("School_ID").distinct.collect.flatMap(_.toSeq)
val bySchoolArray = schools.map(School_ID => model_input.where($"School_ID" <=> School_ID).cache)
import org.apache.spark.sql.DataFrame
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.{Pipeline, PipelineModel}
def trainModel(df: DataFrame): PipelineModel = {
val rf = new RandomForestClassifier()
//omit some parameters
val pipeline = new Pipeline().setStages(Array(rf))
pipeline.fit(df)
}
val bySchoolArrayModels = bySchoolArray.map(df => trainModel(df))
val preds = (0 to schools.length -1).map(i => bySchoolArrayModels(i).transform(bySchoolArray(i)).select("prediction", "label")
preds.write.format("com.databricks.spark.csv").
option("header","true").
save("predictions/pred"+schools(i))
代码在一个小子集上工作正常,但它需要的时间比我预期的要长。在我看来,每次运行单个模型时,Spark都会读取整个文件,并且需要永远完成所有模型运行。我想知道我是否没有正确地缓存文件或者我编码它的方式出了什么问题。
任何建议都会有用。谢谢!
答案 0 :(得分:3)
rdd的方法是不可变的,所以rdd.cache()返回一个新的rdd。因此,您需要将cachedRdd分配给另一个变量,然后重新使用它。否则你没有使用缓存的rdd。
val cachedModelInput = model_input.cache()
val schools = cachedModelInput.select("School_ID").distinct.collect.flatMap(_.toSeq)
....