如何解释Spark OneHotEncoder

时间:2017-02-17 10:05:32

标签: python apache-spark pyspark one-hot-encoding

我阅读了Spark docs的OHE条目,

  

单热编码将一列标签索引映射到一列二进制向量,最多只有一个单值。此编码允许期望连续特征(例如Logistic回归)的算法使用分类特征。

但遗憾的是他们没有对OHE结果给出完整的解释。所以运行给定的代码:

from pyspark.ml.feature import OneHotEncoder, StringIndexer

df = sqlContext.createDataFrame([
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
], ["id", "category"])

stringIndexer = StringIndexer(inputCol="category",      outputCol="categoryIndex")
model = stringIndexer.fit(df)
indexed = model.transform(df)

encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
encoded.show()

得到了结果:

   +---+--------+-------------+-------------+
   | id|category|categoryIndex|  categoryVec|
   +---+--------+-------------+-------------+
   |  0|       a|          0.0|(2,[0],[1.0])|
   |  1|       b|          2.0|    (2,[],[])|
   |  2|       c|          1.0|(2,[1],[1.0])|
   |  3|       a|          0.0|(2,[0],[1.0])|
   |  4|       a|          0.0|(2,[0],[1.0])|
   |  5|       c|          1.0|(2,[1],[1.0])|
   +---+--------+-------------+-------------+

我如何解释OHE的结果(最后一栏)?

1 个答案:

答案 0 :(得分:18)

单热编码将categoryIndex中的值转换为二进制向量,其中最多一个值可以是1.由于有三个值,向量的长度为2,映射如下:

0  -> 10
1  -> 01
2  -> 00

(为什么这样的映射?请参阅this question关于删除最后一个类别的单热编码器。)

categoryVec中的值正是这些值,但是以稀疏格式表示。在这种格式中,不打印矢量的零。第一个值(2)显示向量的长度,第二个值是一个列出零个或多个索引的数组,其中找到非零条目。第三个值是另一个数组,它告诉在这些索引处找到哪些数字。 所以(2,[0],[1.0])表示长度为2的向量,位置0为1.0,其他位置为0。

请参阅:https://spark.apache.org/docs/latest/mllib-data-types.html#local-vector