在Spark \ PySpark中保存\ load模型的正确方法是什么

时间:2015-03-25 12:03:10

标签: python apache-spark pyspark apache-spark-mllib

我使用PySpark和MLlib使用Spark 1.3.0,我需要保存并加载我的模型。我使用这样的代码(取自官方documentation

from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating

data = sc.textFile("data/mllib/als/test.data")
ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2])))
rank = 10
numIterations = 20
model = ALS.train(ratings, rank, numIterations)
testdata = ratings.map(lambda p: (p[0], p[1]))
predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
predictions.collect() # shows me some predictions
model.save(sc, "model0")

# Trying to load saved model and work with it
model0 = MatrixFactorizationModel.load(sc, "model0")
predictions0 = model0.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))

在我尝试使用model0后,我得到了一个很长的回溯,结束于此:

Py4JError: An error occurred while calling o70.predict. Trace:
py4j.Py4JException: Method predict([class org.apache.spark.api.java.JavaRDD]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:333)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
    at py4j.Gateway.invoke(Gateway.java:252)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:207)
    at java.lang.Thread.run(Thread.java:745)

所以我的问题是 - 我做错了吗?据我调试,我的模型存储(本地和HDFS),它们包含许多带有一些数据的文件。我觉得模型保存正确,但可能没有正确加载。我也用Google搜索,但没有发现任何相关内容。

最近在Spark 1.3.0中添加了这个save \ load功能,因此我还有另一个问题 - 在1.3.0版本之前保存\ load模型的推荐方法是什么?我还没有找到任何好方法,至少对于Python来说。我也尝试了Pickle,但遇到了与此处所述相同的问题Save Apache Spark mllib model in python

4 个答案:

答案 0 :(得分:7)

保存模型的一种方法(在Scala中;但在Python中可能类似):

// persist model to HDFS
sc.parallelize(Seq(model), 1).saveAsObjectFile("linReg.model")

然后可以将已保存的模型加载为:

val linRegModel = sc.objectFile[LinearRegressionModel]("linReg.model").first()

另请参阅相关的question

有关详细信息,请参阅(ref

答案 1 :(得分:5)

自2015年3月28日(您的问题上次编辑后的第二天)合并this pull request时,此问题已得到解决。

您只需要从GitHub(git clone git://github.com/apache/spark.git -b branch-1.3)克隆/获取最新版本,然后使用spark/README.md构建它(遵循$ mvn -DskipTests clean package中的说明)。

注意:我在构建Spark时遇到了麻烦,因为Maven很不稳定。我通过使用$ update-alternatives --config mvn并选择具有优先级:150的“路径”来解决该问题,无论这意味着什么。 Explanation here

答案 2 :(得分:2)

我也碰到了这个 - 它看起来像个bug。 我已向spark jira报告。

答案 3 :(得分:1)

使用ML中的管道训练模型,然后使用MLWriter和MLReader保存模型并将其读回。

from pyspark.ml import Pipeline
from pyspark.ml import PipelineModel

pipeTrain.write().overwrite().save(outpath)
model_in = PipelineModel.load(outpath)