Pyspark Dataframe One-Hot编码

时间:2017-07-04 11:02:22

标签: apache-spark pyspark apache-spark-sql apache-spark-mllib one-hot-encoding

我正在使用分类数据在Spark DataFrame上进行数据准备。我需要对分类数据进行One-Hot-Encoding,我在spark 1.6

上尝试了这个
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([
    (0, "a"),
    (1, "b"),
    (2, "c"),
    (3, "a"),
    (4, "a"),
    (5, "c")
], ["id", "category"])

stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
model = stringIndexer.fit(df)
indexed = model.transform(df)
encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
encoded.select("id", "categoryVec").show()

这段代码产生了这种格式的单热编码数据。

+---+-------------+
| id|  categoryVec|
+---+-------------+
|  0|(3,[0],[1.0])|
|  1|(3,[2],[1.0])|
|  2|(3,[1],[1.0])|
|  3|(3,[0],[1.0])|
|  4|(3,[0],[1.0])|
|  5|(3,[1],[1.0])|
+---+-------------+

通常,我对One-Hot编码技术的期望是每个类别的每列和0,1个相应的值。如何从中获取这类数据?

0 个答案:

没有答案