检索Spark Mllib StringIndexer列映射

时间:2017-04-23 19:09:48

标签: scala apache-spark apache-spark-mllib apache-spark-ml

如何从经过训练的Spark MLlib StringIndexerModel获取映射?

val stringIndexer = new StringIndexer()
    .setInputCol("myCol")
    .setOutputCol("myColIdx")
val stringIndexerModel = stringIndexer.fit(data)
val res = stringIndexerModel.transform(data)

上面的代码会在myColIdx中将myCol添加到我的DataFrame映射值到基于值频率的索引。即最常见的值 - > 0,第二最频繁 - >等等...

如何从模型中检索该映射?如果我对模型进行序列化/反序列化,那么映射是否稳定(即我在变换后保证得到相同的结果)?

1 个答案:

答案 0 :(得分:4)

StringIndexerModel使用labels属性公开映射:

stringIndexerModel.labels: Array[String]

其中值对应于连续标签,例如:

val data = Seq("foo", "bar", "foo", "bar", "foobar", "bar").toDF("myCol")

您将关注labels

import org.apache.spark.ml.feature.IndexToString

Array(bar, foo, foobar)

bar索引为0.0,foo为1.0,foobar为2.0。这是模型的属性,在模型为saved时保留。

Pipeline中使用时,您还可以使用IndexToString,它将使用列元数据将索引映射回标签。

indexToString.transform(stringIndexerModel.transform(data)).show
+------+--------+-------------+
| myCol|myColIdx|myColReversed|
+------+--------+-------------+
|   foo|     1.0|          foo|
|   bar|     0.0|          bar|
|   foo|     1.0|          foo|
|   bar|     0.0|          bar|
|foobar|     2.0|       foobar|
|   bar|     0.0|          bar|
+------+--------+-------------+