如何让Spark中的onehotencoder像Pandas中的onehotencoder一样工作?

时间:2017-03-18 15:22:24

标签: apache-spark pyspark one-hot-encoding

当我在Spark中使用onehotencoder时,我会得到第四列中的结果,这是一个稀疏向量。

// +---+--------+-------------+-------------+
// | id|category|categoryIndex|  categoryVec|
// +---+--------+-------------+-------------+
// |  0|       a|          0.0|(3,[0],[1.0])|
// |  1|       b|          2.0|(3,[2],[1.0])|
// |  2|       c|          1.0|(3,[1],[1.0])|
// |  3|      NA|          3.0|    (3,[],[])|
// |  4|       a|          0.0|(3,[0],[1.0])|
// |  5|       c|          1.0|(3,[1],[1.0])|
// +---+--------+-------------+-------------+

然而,我想要的是为类别生成3列,就像它在熊猫中的工作方式一样。

>>> import pandas as pd
>>> s = pd.Series(list('abca'))
>>> pd.get_dummies(s)
   a  b  c
0  1  0  0
1  0  1  0
2  0  0  1
3  1  0  0

2 个答案:

答案 0 :(得分:10)

Spark的OneHotEncoder创建一个稀疏矢量列。要创建类似于pandas OneHotEncoder的输出列,我们需要为每个类别创建一个单独的列。我们可以通过将udf作为参数传递,在pyspark数据框的withColumn函数的帮助下实现这一点。对于前 -

from pyspark.sql.functions import udf,col
from pyspark.sql.types import IntegerType


df = sqlContext.createDataFrame(sc.parallelize(
        [(0,'a'),(1,'b'),(2,'c'),(3,'d')]), ('col1','col2'))

categories = df.select('col2').distinct().rdd.flatMap(lambda x : x).collect()
categories.sort()
for category in categories:
    function = udf(lambda item: 1 if item == category else 0, IntegerType())
    new_column_name = 'col2'+'_'+category
    df = df.withColumn(new_column_name, function(col('col2')))

print df.show()

输出 -

+----+----+------+------+------+------+                                         
|col1|col2|col2_a|col2_b|col2_c|col2_d|
+----+----+------+------+------+------+
|   0|   a|     1|     0|     0|     0|
|   1|   b|     0|     1|     0|     0|
|   2|   c|     0|     0|     1|     0|
|   3|   d|     0|     0|     0|     1|
+----+----+------+------+------+------+

我希望这会有所帮助。

答案 1 :(得分:0)

不能发表评论,因为我没有声望点,所以请回答这个问题。

这实际上是关于火花管道和变压器的最好的事情之一!我不明白为什么你需要以这种格式获得它。你能详细说明吗?