我使用以下代码构建了一个随机森林模型:
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("features")
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val training = labelIndexer.transform(df)
val model = rf.fit(training)
现在我想保存模型,以便稍后使用以下代码进行预测:
val predictions: DataFrame = model.transform(testData)
我查看了Spark文档here,但没有找到任何选项。任何的想法? 我花了几个小时来建造这个模型,如果Spark破碎了我就不能把它拿回来。
答案 0 :(得分:2)
可以使用Spark 1.6使用saveAsObjectFile()为基于管道的模型和基本模型在HDFS中保存和重新加载基于树的模型。 以下是基于管道的模型的示例。
// model
val model = pipeline.fit(trainingData)
// Create rdd using Seq
sc.parallelize(Seq(model), 1).saveAsObjectFile("hdfs://filepath")
// Reload model by using it's class
// You can get class of object using object.getClass()
val sameModel = sc.objectFile[PipelineModel]("filepath").first()
答案 1 :(得分:2)
对于RandomForestClassifier save&负载模型:测试火花1.6.2 + scala in ml(在spark 2.0中你可以直接保存模型选项)
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier //imports
val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
val model = classifier.fit(trainingData)
sc.parallelize(Seq(model), 1).saveAsObjectFile(modelSavePath)
//保存模型
val linRegModel = sc.objectFile[RandomForestClassificationModel](modelSavePath).first() //load model
`val predictions1 = linRegModel.transform(testData)` //predictions1 is dataframe
答案 2 :(得分:1)
它位于MLWriter
界面中 - 可通过模型上的writer
属性访问:
model.asInstanceOf[MLWritable].write.save(path)
这是界面:
abstract class MLWriter extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false
/**
* Saves the ML instances to the input path.
*/
@Since("1.6.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String): Unit = {
这是早期版本mllib
/ spark.ml
更新模型似乎不可写:
线程“main”中的异常java.lang.UnsupportedOperationException: 管道写入将在此管道上失败,因为它包含一个阶段 它没有实现Writable。不可写的阶段: 类型为rfc_4e467607406f org.apache.spark.ml.classification.RandomForestClassificationModel
因此可能没有直接的解决方案。
答案 3 :(得分:1)
这是与上述Scala saveAsObjectFile()答案相对应的PySpark v1.6实现。
它强制Python对象与Java对象之间来回转换,以使用saveAsObjectFile()实现序列化。
没有Java强制性,我在序列化过程中遇到了奇怪的Py4J错误。如果有人实施起来比较简单,请编辑或评论。
保存训练有素的RandomForestClassificationModel对象:
# Save RandomForestClassificationModel to hdfs
gateway = sc._gateway
java_list = gateway.jvm.java.util.ArrayList()
java_list.add(rfModel._java_obj)
modelRdd = sc._jsc.parallelize(java_list)
modelRdd.saveAsObjectFile("hdfs:///some/path/rfModel")
加载训练有素的RandomForestClassificationModel对象:
# Load RandomForestClassificationModel from hdfs
rfObjectFileLoaded = sc._jsc.objectFile("hdfs:///some/path/rfModel")
rfModelLoaded_JavaObject = rfObjectFileLoaded.first()
rfModelLoaded = RandomForestClassificationModel(rfModelLoaded_JavaObject)
predictions = rfModelLoaded.transform(test_input_df)