Why did the RandomForestClassificationModel numTrees value change after loading?

Time: 2018-06-07 11:22:21

Tags: apache-spark apache-spark-mllib random-forest apache-spark-ml

I have a pre-trained RandomForestClassificationModel on my local disk. To transfer the model files to the Hadoop file system, I use the following code:

    import org.apache.spark.ml.classification.RandomForestClassificationModel

    val RFModel_test = RandomForestClassificationModel.load(modelPathLocal)
    println(RFModel_test.uid + "  " + RFModel_test.numFeatures + "  "+ RFModel_test.getNumTrees)
    RFModel_test.write.overwrite().save(modelPath)

    val RFModel_test2 = RandomForestClassificationModel.load(modelPath)

    // modelPathLocal = D://RandomForestModel_D
    // modelPath = ${hdfs}/spark/model/test/RandomForestModel_D
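
For reference, a quick check right after the first load might already show where the value changes. This is only a minimal sketch reusing the RFModel_test value from the code above; the printed numbers are what I would expect if the numTrees param falls back to its default on load, not verified output:

    // Compare the actual size of the tree ensemble with the numTrees param
    // of the freshly loaded model. If they disagree (e.g. 4 vs. 20), the
    // re-saved metadata becomes inconsistent and the second load fails.
    val actualTrees = RFModel_test.trees.length   // trees actually loaded from disk
    val paramTrees  = RFModel_test.getNumTrees    // value of the numTrees param
    println(s"trees in ensemble = $actualTrees, numTrees param = $paramTrees")
    require(actualTrees == paramTrees,
      s"numTrees param ($paramTrees) does not match ensemble size ($actualTrees)")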

The second load, RandomForestClassificationModel.load(modelPath), then failed with this error:

    Exception in thread "Thread-0" java.lang.IllegalArgumentException: requirement failed: RandomForestClassificationModel.load expected 20 trees based on metadata but found 4 trees.
    at scala.Predef$.require(Predef.scala:224)
    at org.apache.spark.ml.classification.RandomForestClassificationModel$RandomForestClassificationModelReader.load(RandomForestClassifier.scala:325)
    at org.apache.spark.ml.classification.RandomForestClassificationModel$RandomForestClassificationModelReader.load(RandomForestClassifier.scala:303)
    at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:223)
    at org.apache.spark.ml.classification.RandomForestClassificationModel$.load(RandomForestClassifier.scala:287)
    at dnsCovDetect.RFModelDetect.detect(RFModelDetect.scala:102)
    at dnsCovDetect.DnsThread$$anon$1.run(DnsThread.scala:12)

After comparing the metadata files on the local disk and on HDFS, I found two differences apart from the timestamp. That reminded me that I trained the model with Spark 2.0.2, whereas I am now running 2.3.0. After loading the model, numTrees has changed: I set it to 4 trees, but it is now the default value of 20. I would like to know why. Why does the value change? (The two metadata dumps follow; a possible workaround is sketched after them.)

    local{"class":"org.apache.spark.ml.classification.RandomForestClassificationModel",
"timestamp":1522317004726,"-----------------
sparkVersion":"2.0.2",----------------------
"uid":"rfc_8367e2954dcc",
"paramMap":{"subsamplingRate":1.0,"featuresCol":"features","minInstancesPerNode":1,
"checkpointInterval":10,"impurity":"gini","maxMemoryInMB":256,
"maxDepth":4,"cacheNodeIds":false,
"seed":207336481,"labelCol":"label","minInfoGain":0.0,
"predictionCol":"prediction","featureSubsetStrategy":"auto","rawPredictionCol":"rawPrediction",
"probabilityCol":"probability","maxBins":10},"numFeatures":7,"numClasses":2,"numTrees":4}


hdfs{"class":"org.apache.spark.ml.classification.RandomForestClassificationModel",
"timestamp":1528358664567,--------------------
"sparkVersion":"2.3.0"------------------------
,"uid":"rfc_8367e2954dcc",
"paramMap":{"subsamplingRate":1.0,"featuresCol":"features","minInstancesPerNode":1,
"checkpointInterval":10,"impurity":"gini","maxMemoryInMB":256,
"numTrees":20,-------------------------------
"maxDepth":4,"cacheNodeIds":false,
"seed":207336481,"labelCol":"label","minInfoGain":0.0,
"predictionCol":"prediction","featureSubsetStrategy":"auto","rawPredictionCol":"rawPrediction",
"probabilityCol":"probability","maxBins":10},"numFeatures":7,"numClasses":2,"numTrees":20}

0 Answers:

No answers yet.