Bucketing and one-hot encoding in PySpark

Date: 2019-05-13 09:56:29

Tags: python pyspark pyspark-sql

I have a PySpark DataFrame with the following columns:

id        Age
1         30
2         25
3         21

I have the following age bucket boundaries: [20, 24, 27, 30]

My expected result:

id    Age    age_bucket     age_27_30     age_24_27   age_20_24
1     30      (27-30]           1            0           0
2     25      (24-27]           0            1           0
3     21      (20-24]           0            0           1

My current code:

from pyspark.ml.feature import Bucketizer
bucketizer = Bucketizer(splits=[ 20,24,27,30 ],inputCol="Age", outputCol="age_bucket")
df1 = bucketizer.setHandleInvalid("keep").transform(df)
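For reference, Bucketizer with splits [20, 24, 27, 30] produces three buckets: [20, 24), [24, 27) and [27, 30]. Each bucket is closed on the left and open on the right, except the last one, which also includes its upper bound. Values outside the splits are treated as errors; handleInvalid="keep" only affects NaN values, which are put into an extra bucket. The plain-Python function below is a hypothetical illustration of that mapping (bucket_index is not a Spark API). Note that it implies an age of exactly 24 lands in bucket 1, which the expected output labels "(24-27]", so the "(a-b]" labels do not exactly match Bucketizer's intervals.

```python
import math
from bisect import bisect_right


def bucket_index(value, splits, handle_invalid="keep"):
    """Hypothetical sketch of Bucketizer's value -> bucket index mapping."""
    if math.isnan(value):
        if handle_invalid == "keep":
            # NaN goes into an extra bucket after the last real one
            return float(len(splits) - 1)
        raise ValueError("NaN value and handle_invalid is not 'keep'")
    if value < splits[0] or value > splits[-1]:
        # Out-of-range values are errors regardless of handle_invalid
        raise ValueError(f"{value} is outside the splits {splits}")
    if value == splits[-1]:
        # The last bucket is closed on the right
        return float(len(splits) - 2)
    # Buckets are [splits[i], splits[i + 1]) for all other values
    return float(bisect_right(splits, value) - 1)


splits = [20, 24, 27, 30]
print([bucket_index(age, splits) for age in (30, 25, 21, 24)])
```

The bucket indices are doubles, matching the 2.0 / 1.0 / 0.0 values that Bucketizer writes into age_bucket.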

2 answers:

Answer 0 (score: 0)

Use OneHotEncoderEstimator (Spark 2.4; in Spark 3.0 it was renamed to OneHotEncoder):

spark.version
'2.4.3'

df = spark.createDataFrame([(1, 30), (2, 25), (3, 21),],["id", "age"])

# buckets
from pyspark.ml.feature import Bucketizer

bucketizer = Bucketizer(splits=[20,24,27,30],inputCol="age", outputCol="age_bucket", handleInvalid="keep")
buckets = bucketizer.transform(df)

buckets.show()
+---+---+----------+
| id|age|age_bucket|
+---+---+----------+
|  1| 30|       2.0|
|  2| 25|       1.0|
|  3| 21|       0.0|
+---+---+----------+

# one-hot encoding
from pyspark.ml.feature import OneHotEncoderEstimator

encoder = OneHotEncoderEstimator(inputCols=["age_bucket"], outputCols=["age_ohe"])

model = encoder.fit(buckets)
transform_model = model.transform(buckets)

transform_model.show()
+---+---+----------+-------------+
| id|age|age_bucket|      age_ohe|
+---+---+----------+-------------+
|  1| 30|       2.0|    (2,[],[])|
|  2| 25|       1.0|(2,[1],[1.0])|
|  3| 21|       0.0|(2,[0],[1.0])|
+---+---+----------+-------------+

# wrap it up in a pipeline if you want
from pyspark.ml import Pipeline

bucketizer = Bucketizer(splits=[20,24,27,30], inputCol="age", outputCol="age_bucket")
encoder = OneHotEncoderEstimator(inputCols=["age_bucket"], outputCols=["age_ohe"])

pipeline = Pipeline(stages=[bucketizer, encoder])

model = pipeline.fit(df)
fe = model.transform(df)

fe.show()
+---+---+----------+-------------+
| id|age|age_bucket|      age_ohe|
+---+---+----------+-------------+
|  1| 30|       2.0|    (2,[],[])|
|  2| 25|       1.0|(2,[1],[1.0])|
|  3| 21|       0.0|(2,[0],[1.0])|
+---+---+----------+-------------+
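Note that the encoded vector has size 2, not 3, and the row with age 30 gets an all-zero vector (2,[],[]). This is because OneHotEncoderEstimator drops the last category by default (dropLast=True); passing dropLast=False gives a size-3 vector with one slot per bucket. The plain-Python function below is a hypothetical sketch of that behaviour (one_hot_sparse is not a Spark API). It returns (size, indices, values) tuples that mirror how Spark prints sparse vectors.

```python
def one_hot_sparse(index, num_categories, drop_last=True):
    """Hypothetical sketch of one-hot encoding a category index into a sparse vector."""
    idx = int(index)  # Bucketizer produces doubles such as 2.0
    size = num_categories - 1 if drop_last else num_categories
    if idx < size:
        # A single 1.0 at the category's position
        return (size, [idx], [1.0])
    # With drop_last, the last category has no slot and becomes all zeros
    return (size, [], [])


for bucket in (2.0, 1.0, 0.0):
    print(one_hot_sparse(bucket, 3), one_hot_sparse(bucket, 3, drop_last=False))
```

To get separate 0/1 columns like in the question, the vector still has to be expanded. In Spark 3.0+, pyspark.ml.functions.vector_to_array converts it into an array column whose elements can then be selected individually.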

Answer 1 (score: 0)

If you want exactly the same result as in the question, OneHotEncoderEstimator won't get you there without some additional fancy mapping tricks.

I would use a join here (df1 is the bucketized DataFrame from the question's code):

age_buckets = [20, 24, 27, 30]
bins = list(zip(age_buckets, age_buckets[1:]))

data = [[i] + ['({0}-{1}]'.format(*bin_endpoints)] + [0] * i + [1] + [0] * (len(bins) - i - 1) 
        for i, bin_endpoints in enumerate(bins)]
schema = ', '.join('age_bucket_{}_{}: int'.format(start, end) 
                   for start, end in zip(age_buckets, age_buckets[1:]))

join_df = spark.createDataFrame(data, 'age_bucket: int, age_bucket_string: string, ' + schema)

result = (df1.join(join_df, on='age_bucket', how='left')
             .drop('age_bucket')
             .withColumnRenamed('age_bucket_string', 'age_bucket')
             .orderBy('id'))
result.show()

Output:

+---+---+----------+----------------+----------------+----------------+
| id|Age|age_bucket|age_bucket_20_24|age_bucket_24_27|age_bucket_27_30|
+---+---+----------+----------------+----------------+----------------+
|  1| 30|   (27-30]|               0|               0|               1|
|  2| 25|   (24-27]|               0|               1|               0|
|  3| 21|   (20-24]|               1|               0|               0|
+---+---+----------+----------------+----------------+----------------+
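The only non-standard piece of this approach is the lookup table. As a sketch, the same label and flag construction can be run in plain Python, with a dictionary lookup standing in for the left join on age_bucket. The input rows below are hypothetical: the example data paired with the bucket indices Bucketizer assigns. A missing key yields None values, just as a left join yields nulls.

```python
age_buckets = [20, 24, 27, 30]
bins = list(zip(age_buckets, age_buckets[1:]))

# One lookup entry per bucket index: the label string, then one 0/1 flag per bucket
lookup = {
    i: ["({0}-{1}]".format(*bin_endpoints)] + [0] * i + [1] + [0] * (len(bins) - i - 1)
    for i, bin_endpoints in enumerate(bins)
}
flag_columns = ["age_bucket_{}_{}".format(start, end) for start, end in bins]

# Hypothetical input rows: (id, Age, age_bucket index as produced by Bucketizer)
rows = [(1, 30, 2.0), (2, 25, 1.0), (3, 21, 0.0)]

# Emulate the left join: replace the bucket index with its label and flags,
# falling back to None (like SQL nulls) for an index with no lookup entry
missing = [None] * (len(bins) + 1)
result = [(row_id, age, *lookup.get(int(bucket), missing)) for row_id, age, bucket in rows]

print(["id", "Age", "age_bucket"] + flag_columns)
for row in result:
    print(row)
```

Because everything is derived from age_buckets, adding or changing a split updates the labels, the flag column names and the lookup rows together.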