I trained a machine learning model with the Spark 1.5.2 RDD-based API (the mllib package), say "Mymodel123":
org.apache.spark.mllib.tree.model.RandomForestModel Mymodel123 = ...;
Mymodel123.save(sc, "path"); // sc is the SparkContext, not a string
Now I am using the Spark 2.2.0 Dataset-based API (the ml package). Is there any way to load the model (Mymodel123) with the Dataset-based API?
// Note: the ml-package load takes only a path, no SparkContext argument
org.apache.spark.ml.classification.RandomForestClassificationModel newModel = org.apache.spark.ml.classification.RandomForestClassificationModel.load("path");
Answer 0 (score: 1)
There is no public API that can do this, but the new RandomForestClassificationModel wraps the old mllib API and provides package-private methods that can be used to convert an mllib model into an ml model:

/** Convert a model from the old API */
private[ml] def fromOld(
    oldModel: OldRandomForestModel,
    parent: RandomForestClassifier,
    categoricalFeatures: Map[Int, Int],
    numClasses: Int,
    numFeatures: Int = -1): RandomForestClassificationModel = {
  require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
    s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
  val newTrees = oldModel.trees.map { tree =>
    // parent for each tree is null since there is no good way to set this.
    DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
  }
  val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
  new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
}

So it is not impossible. From Java you can call it directly (Scala's package-private modifiers are compiled to public methods, so Java does not enforce them); in Scala you have to put your adapter code in the org.apache.spark.ml package, since the method is marked private[ml].
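To make the Scala route concrete, here is a minimal sketch of such an adapter. It assumes Spark 2.2.x on the classpath; the object name OldModelAdapter, the method name loadAsMl, and the numClasses/categoricalFeatures parameters are placeholders you would fill in with your model's actual metadata (the old save format does not record them).

```scala
// Placed in Spark's own package so the private[ml] converter is visible.
package org.apache.spark.ml

import org.apache.spark.SparkContext
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}

// Hypothetical adapter object; name and signature are illustrative only.
object OldModelAdapter {
  def loadAsMl(
      sc: SparkContext,
      path: String,
      numClasses: Int,
      categoricalFeatures: Map[Int, Int] = Map.empty): RandomForestClassificationModel = {
    // Load the model that was saved with the old RDD-based API...
    val oldModel = OldRandomForestModel.load(sc, path)
    // ...then hand it to the package-private converter shown above.
    // parent is null because the original estimator is no longer available.
    RandomForestClassificationModel.fromOld(
      oldModel, null, categoricalFeatures, numClasses)
  }
}
```

Calling code elsewhere can then simply do `OldModelAdapter.loadAsMl(sc, "path", numClasses = 2)` and use the result anywhere a Spark ML transformer is expected.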