PySpark:OneHotEncoder的输出看起来很奇怪

时间:2018-04-03 14:38:28

标签: apache-spark pyspark apache-spark-mllib one-hot-encoding

Spark文档包含OneHotEncoder的{​​{3}}:

from pyspark.ml.feature import OneHotEncoder, StringIndexer

df = spark.createDataFrame([
    (0, "a"),
    (1, "b"),
    (2, "c"),
    (3, "a"),
    (4, "a"),
    (5, "c")
], ["id", "category"])

stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
model = stringIndexer.fit(df)
indexed = model.transform(df)

encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
encoded.show()

我希望专栏categoryVec看起来像这样:

[0.0, 0.0]
[1.0, 0.0]
[0.0, 1.0]
[0.0, 0.0]
[0.0, 0.0]
[0.0, 1.0]

但是categoryVec实际上是这样的:

(2, [0], [1.0])
    (2, [], [])
(2, [1], [1.0])
(2, [0], [1.0])
(2, [0], [1.0])
(2, [1], [1.0])    

这是什么意思?我应该如何阅读这个输出,这个奇怪的格式背后的原因是什么?

1 个答案:

答案 0 :(得分:4)

这里没什么奇怪的。这些只是SparseVectors其中:

  • 第一个元素是向量的大小
  • 第一个数组[...]是索引列表。
  • 第二个数组是值列表。

未明确列出的指数为0.0。