本地训练的和Dataproc训练的Spark ML模型之间的不一致

时间:2020-05-27 16:15:45

标签: scala apache-spark google-cloud-dataproc

我正在将Spark从2.3.1版本升级到2.4.5。我正在使用Dataproc映像1.4.27-debian9在Google Cloud Platform的Dataproc上使用Spark 2.4.5重新训练模型。当我使用Spark 2.4.5将Dataproc生成的模型加载到本地计算机上以验证模型时。不幸的是,我收到以下异常:

20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.

加载模型的代码非常简单:

import org.apache.spark.ml.PipelineModel

object ModelLoad {
  def main(args: Array[String]): Unit = {
    val modelInputPath = getClass.getResource("/model.ml").getPath
    val model = PipelineModel.load(modelInputPath)
  }
}

我遵循堆栈跟踪检查了1_gbtc_961a6ef213b2/metadata/part-00000模型元数据文件,并发现了以下内容:

{
    "class": "org.apache.spark.ml.classification.GBTClassificationModel",
    "timestamp": 1590593177604,
    "sparkVersion": "2.4.5",
    "uid": "gbtc_961a6ef213b2",
    "paramMap": {
        "maxIter": 50
    },
    "defaultParamMap": {
        ...
        "impurity": "variance",
        ...
    },
    "numFeatures": 1,
    "numTrees": 50
}

杂质被设置为variance,但是我的本地火花2.4.5希望它是gini。为了进行完整性检查,我在本地spark 2.4.5上对模型进行了重新训练。模型元数据文件中的impurity设置为gini

因此,我在GBT Javadoc中检查了spark 2.4.5 setImpurity method。它说The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."。 Dataproc使用的spark 2.4.5似乎与Apache Spark文档一致。但是,我从Maven Central使用的Spark 2.4.5将impurity的值设置为gini

有人知道为什么Dataproc中的Spark 2.4.5与Maven Central之间存在这种不一致吗?

我创建了一个简单的培训代码以在本地重现结果:

import java.nio.file.Paths

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object SimpleModelTraining {
  def main(args: Array[String]) {


    val currentRelativePath = Paths.get("")
    val save_file_location = currentRelativePath.toAbsolutePath.toString

    val spark = SparkSession.builder()
      .config("spark.driver.host", "127.0.0.1")
      .master("local")
      .appName("spark-test")
      .getOrCreate()

    val df: DataFrame = spark.createDataFrame(Seq(
      (0, 0),
      (1, 0),
      (1, 0),
      (0, 1),
      (0, 1),
      (0, 1),
      (0, 2),
      (0, 2),
      (0, 2),
      (0, 3),
      (0, 3),
      (0, 3),
      (1, 4),
      (1, 4),
      (1, 4)
    )).toDF("label", "category")

    val pipeline: Pipeline = new Pipeline().setStages(Array(
      new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
      new GBTClassifier().setMaxIter(30)
    ))

    val pipelineModel: PipelineModel = pipeline.fit(df)
    pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
  }
}

谢谢!

1 个答案:

答案 0 :(得分:1)

Dataproc back-ported中的火花SPARK-25959的一个修复程序,可能导致本地训练的ML模型和Dataproc训练的ML模型之间出现这种不一致。