有效地扩展Row数组以分隔列

时间:2018-03-12 13:50:29

标签: apache-spark pyspark spark-dataframe user-defined-functions

我有一个spark数据帧,其中一个字段是Row结构数组。我需要将它扩展到自己的列中。其中一个问题是在数组中,有时缺少字段。

以下是一个例子:

from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql import functions as udf

spark = SparkSession.builder.getOrCreate()

# data
rows = [{'status':'active','member_since':1990,'info':[Row(tag='name',value='John'),Row(tag='age',value='50'),Row(tag='phone',value='1234567')]},
        {'status':'inactive','member_since':2000,'info':[Row(tag='name',value='Tom'),Row(tag='phone',value='1234567')]},
        {'status':'active','member_since':2015,'info':[Row(tag='name',value='Steve'),Row(tag='age',value='28')]}]

# create dataframe
df = spark.createDataFrame(rows)

# transform info to dict
to_dict = udf.UserDefinedFunction(lambda s:dict(s),MapType(StringType(),StringType()))
df = df.withColumn("info_dict",to_dict("info"))

# extract name, NA if not exists
extract_name = udf.UserDefinedFunction(lambda s:s.get("name","NA"))
df = df.withColumn("name",extract_name("info_dict"))

# extract age, NA if not exists
extract_age = udf.UserDefinedFunction(lambda s:s.get("age","NA"))
df = df.withColumn("age",extract_age("info_dict"))

# extract phone, NA if not exists
extract_phone = udf.UserDefinedFunction(lambda s:s.get("phone","NA"))
df = df.withColumn("phone",extract_phone("info_dict"))

df.show()

你可以看到汤姆'年龄'不见了;对于史蒂夫','电话'不见了。与上面的代码片段一样,我目前的解决方案是首先将数组转换为dict,然后将每个单独的字段解析为其列。结果是这样的:

+--------------------+------------+--------+--------------------+-----+---+-------+
|                info|member_since|  status|           info_dict| name|age|  phone|
+--------------------+------------+--------+--------------------+-----+---+-------+
|[[name, John], [a...|        1990|  active|[name -> John, ph...| John| 50|1234567|
|[[name, Tom], [ph...|        2000|inactive|[name -> Tom, pho...|  Tom| NA|1234567|
|[[name, Steve], [...|        2015|  active|[name -> Steve, a...|Steve| 28|     NA|
+--------------------+------------+--------+--------------------+-----+---+-------+

我真的只想要列' status',' member_since',' name',' age'和#'电话'。由于UDF,此解决方案有效但速度很慢。还有更快的选择吗?感谢

1 个答案:

答案 0 :(得分:0)

我可以想到使用DataFrame函数的两种方法。我相信第一个应该更快,但代码不那么优雅。第二个更紧凑,但可能更慢。

方法1:动态创建地图

此方法的核心是将您的Row变为MapType()。这可以使用pyspark.sql.functions.create_map()使用functools.reduce()operator.add()来实现。

from operator import add
import pyspark.sql.functions as f

f.create_map(
    *reduce(
        add,
        [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
         for k in range(3)]
    )
)

问题在于,没有办法(AFAIK)以简单的方式动态确定WrappedArray或iterate through的长度。如果缺少值,则会导致错误,因为映射键不能是null。但是,由于我们知道列表可以包含1,2,3个元素,因此我们可以测试每种情况。

df.withColumn(
    'map',
    f.when(f.size(f.col('info')) == 1, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(1)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 2, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(2)]
            )
        )
    ).otherwise(
    f.when(f.size(f.col('info')) == 3, 
        f.create_map(
            *reduce(
                add,
                [[f.col('info')['tag'].getItem(k), f.col('info')['value'].getItem(k)]
                 for k in range(3)]
            )
        )
    )))
).select(
    ['member_since', 'status'] + [f.col("map").getItem(k).alias(k) for k in keys]
).show(truncate=False)

最后一步使用this answer中描述的方法将'map'键转换为列。

这会产生以下输出:

+------------+--------+-----+----+-------+
|member_since|status  |name |age |phone  |
+------------+--------+-----+----+-------+
|1990        |active  |John |50  |1234567|
|2000        |inactive|Tom  |null|1234567|
|2015        |active  |Steve|28  |null   |
+------------+--------+-----+----+-------+

方法2:使用explode,groupBy和pivot

首先在'info'列上使用pyspark.sql.functions.explode(),然后使用'tag''value'列作为create_map()的参数:

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', 'map')\
    .show(truncate=False)
#+------------+------------+--------+---------------------+
#|id          |member_since|status  |map                  |
#+------------+------------+--------+---------------------+
#|85899345920 |1990        |active  |Map(name -> John)    |
#|85899345920 |1990        |active  |Map(age -> 50)       |
#|85899345920 |1990        |active  |Map(phone -> 1234567)|
#|180388626432|2000        |inactive|Map(name -> Tom)     |
#|180388626432|2000        |inactive|Map(phone -> 1234567)|
#|266287972352|2015        |active  |Map(name -> Steve)   |
#|266287972352|2015        |active  |Map(age -> 28)       |
#+------------+------------+--------+---------------------+

我还使用pyspark.sql.functions.monotonically_increasing_id()添加了一列'id',以确保我们可以跟踪哪些行属于同一记录。

现在我们可以展开地图列groupBy()pivot()。我们可以使用pyspark.sql.functions.first()作为groupBy()的汇总函数,因为我们知道每个组中只有一个'value'

df.withColumn('id', f.monotonically_increasing_id())\
    .withColumn('exploded', f.explode(f.col('info')))\
    .withColumn(
        'map', 
        f.create_map(*[f.col('exploded')['tag'], f.col('exploded')['value']]).alias('map')
    )\
    .select('id', 'member_since', 'status', f.explode('map'))\
    .groupBy('id', 'member_since', 'status').pivot('key').agg(f.first('value'))\
    .select('member_since', 'status', 'age', 'name', 'phone')\
    .show()
#+------------+--------+----+-----+-------+
#|member_since|  status| age| name|  phone|
#+------------+--------+----+-----+-------+
#|        1990|  active|  50| John|1234567|
#|        2000|inactive|null|  Tom|1234567|
#|        2015|  active|  28|Steve|   null|
#+------------+--------+----+-----+-------+