如何在PySpark ML中创建自定义SQLTransformer以透视数据

时间:2018-08-23 00:55:10

标签: python apache-spark pyspark apache-spark-mllib

我有一个类似于以下结构的数据框:

# Prepare training data
training = spark.createDataFrame([
    (990011, 1001, 01, "Salary", 1000, 0.0),
    (990011, 1002, 02, "POS Purchase", 50, 0.0),
    (990022, 1003, 01, "Cash Withdrawl", 500, 1.0),
    (990022, 1004, 02, "Interest Charge", 35, 1.0)
], ["customer_id", "transaction_id", "week_of_year", "category", "amount", "label"])

我能够使用以下PySpark动态地旋转数据,从而消除了每周和每个类别的硬编码案例说明:

# Attempt 1
tx_pivot = training \
    .withColumn("week_of_year", sf.concat(sf.lit("T"), sf.col("week_of_year"))) \
    .groupBy("customer_id") \
    .pivot("week_of_year") \
    .sum("amount")

tx_pivot.show(20)

我想开发一个自定义的Transformer来动态地旋转数据,因此我可以将此自定义的Transform阶段合并到Spark ML Pipeline中。不幸的是,Spark / PySpark中当前的SQLTransfomer仅支持SQL,例如E.g. “选择...来自”(请参阅​​https://github.com/apache/spark/blob/master/python/pyspark/ml/feature.py)。

任何有关如何创建自定义Transformer来动态旋转数据的指南都将不胜感激。

1 个答案:

答案 0 :(得分:1)

实现自定义转换器很容易,它接受一个数据帧并返回另一个数据帧。就您而言:

import pyspark.ml.pipeline.Transformer as Transformer

class PivotTransformer(Transformer):

    def _transform(self, data):           
        return data.withColumn("week_of_year",sf.concat(sf.lit("T"),\
                    sf.col("week_of_year"))) \
                   .groupBy("customer_id") \
                   .pivot("week_of_year") \
                   .sum("amount")