Pyspark将列表的列转换为嵌套结构列

时间:2018-10-20 13:36:31

标签: python apache-spark pyspark apache-spark-sql user-defined-functions

我正在尝试将一组难看的文本字符串转换为代表性的PySpark数据框。我停留在将包含字符串列表的列转换为包含行的嵌套结构的列的最后一步。对于列表中的每个字符串,我使用python字典理解将其标准化为相同的字段。当我尝试通过列上的udf进行转换时,它失败了。

我的“记录”列包含类似这样的字符串列表...

['field1, field2, field3, field4', 'field1, field2, field3, field4'..]

幸运的是,字符串结构定义明确,包含字符串和整数,因此我有一个Python字典理解功能,可以拆分和分配名称。

def extract_fields(row: str) -> dict:
  fields = row.split(",")
  return { 'field1': fields[0], 'field2': fields[1], ...} 

这对于将单个字符串转换为行的效果很好

from pyspark.sql import Row
Row(**extract_fields( sample_string))

因此,我想我可以使用UDF将列转换为嵌套结构的列。

nest = sqlfn.udf(lambda x: [Row(**extract_fields(row)) for row in x])

通常我会为UDF添加返回的类型,但是我不知道如何指示行数组。在稍后执行之前,我不会收到任何错误。

所以,现在,当我尝试将其应用于数据框时,

test = df.select(nest(df.records).alias('expanded')
test.show(5)

我收到此错误:

expected zero arguments for construction of ClassDict (for 
pyspark.sql.types._create_row)

我发现的与此错误相关的其他问题似乎表明它们的字典中存在类型错误,但就我而言,我的字典的类型为字符串和整数。我还尝试了一个仅包含单个字符串列表的小示例,并获得了相同的答案。

我的预期结果是,新列“扩展”为具有嵌套行结构的列,该列中的单个行类似于:

Row(expanded = [Row(field1='x11', field2='x12',...), Row(field1='x21', 
field2='x22',....) ] )

有什么建议吗?

1 个答案:

答案 0 :(得分:0)

TL; DR pyspark.sql.Row对象无法从udf返回。

已知形状

如果架构定义正确,并且您不会array<struct<...>>作为结果,则应该使用标准的tuple。在这种情况下,基本的解析功能可以这样实现*:

from typing import List, Tuple

def extract_fields(row: str) -> Tuple[str]:
    # Here we assume that each element has the  expected number of fields
    # In practice you should validate the data
    return tuple(row.split(","))

并为udf提供输出模式:

schema = ("array<struct<"
          "field1: string, field2: string, field3: string, field4: string"
          ">>")

@sqlfn.udf(schema)
def extract_multile_fields(rows: List[str]) -> List[Tuple[str]]:
    return [extract_fields(row) for row in rows]

result = df.select(extract_multile_fields("x"))
result.show(truncate=False)
+--------------------------------------------------------------------------+
|extract_multile_fields(x)                                                 |
+--------------------------------------------------------------------------+
|[[field1,  field2,  field3,  field4], [field1,  field2,  field3,  field4]]|
+--------------------------------------------------------------------------+

如果字段的数量很大,那么您可能更喜欢以编程方式构造模式,而不是使用DDL字符串:

from pyspark.sql.types import ArrayType, StringType, StructField, StructType

schema = ArrayType(StructType(
    [StructField(f"field$i", StringType()) for i in range(1, 5)]
))

在Spark 2.4或更高版本中,您还可以直接使用内置函数:

from pyspark.sql.column import Column

def extract_multile_fields_(col: str) -> Column:
    return sqlfn.expr(f"""transform(
        -- Here we parameterize input with {{}}
        transform(`{col}`, s -> split(s, ',')),  
        -- Adjust the list of fields and cast if necessary 
        a -> struct(
            a[0] as field1, a[1] as field2, a[2] as field3, a[3] as field4)
    )""")


result = df.select(extract_multile_fields_("x").alias("records"))
result.show(truncate=False)
+--------------------------------------------------------------------------+
|records                                                                   |
+--------------------------------------------------------------------------+
|[[field1,  field2,  field3,  field4], [field1,  field2,  field3,  field4]]|
+--------------------------------------------------------------------------+

未知形状

如果数据的形状未知,则array<struct<...>>不是DataType的正确选择。在这种情况下,您可以尝试使用array<map<..., ...>>,但这要求所有值都具有相同的类型:

from typing import Dict

def extract_fields(row: str) -> Dict[str, str]:
    return ... # TODO: Implement the logic

@sqlfn.udf("array<map<string, string>>")
def extract_multile_fields(rows: List[str]) -> List[Dict[str, str]]:
    return [extract_fields(row) for row in rows]

*请注意,所有记录必须具有相同的形状。如果某些字段丢失。您应该用None填补空白。