为什么我不能加载PySpark RandomForestClassifier模型?

时间:2016-10-30 08:49:01

标签: apache-spark pyspark apache-spark-mllib

我无法加载Spark保存的RandomForestClassificationModel。

环境:Apache Spark 2.0.1,在小型(4机器)群集上运行的独立模式。没有HDFS - 一切都保存到本地磁盘。

构建并保存模型:

classifier = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=50)
model = classifier.fit(train)
result = model.transform(test)
model.write().save("/tmp/models/20161030-RF-topics-cats.model")

稍后,在另一个程序中:

model = RandomForestClassificationModel.load("/tmp/models/20161029-RF-topics-cats.model")

给出:

Py4JJavaError: An error occurred while calling o81.load.
: org.apache.spark.sql.AnalysisException: Unable to infer schema for ParquetFormat at /tmp/models/20161029-RF-topics-cats.model/treesMetadata. It must be specified manually;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$16.apply(DataSource.scala:411)
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$16.apply(DataSource.scala:411)
    at scala.Option.getOrElse(Option.scala:121)
    at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:410)
    at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:149)
    at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:439)
    at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:423)
    at org.apache.spark.ml.tree.EnsembleModelReadWrite$.loadImpl(treeModels.scala:441)
    at org.apache.spark.ml.classification.RandomForestClassificationModel$RandomForestClassificationModelReader.load(RandomForestClassifier.scala:301

我注意到当我使用Naive Bayes分类器时,相同的代码可以正常工作。

1 个答案:

答案 0 :(得分:1)

将模型保存到HDFS,稍后从HDFS读取模型可能会解决您的问题。

您有4个节点,每个节点都有自己的本地磁盘。 您正在使用model.write()。save(“/ temp / xxx”)

稍后,在另一个程序中: 您正在使用load(“/ temp / xxx”)

由于有4个节点,有4个不同的本地磁盘,因此我不清楚在write.save()操作期间究竟保存了什么(以及哪个本地磁盘),以及究竟正在加载什么()和本地磁盘。