Nest rows of a dataframe as an array column

Date: 2018-08-16 12:05:58

Tags: pyspark apache-spark-sql pyspark-sql

I have a dataframe that looks more or less like this:

| id | category | value | item_id |
|----|----------|-------|---------|
| 1  | 1        | 1     | 1       |
| 2  | 2        | 2     | 2       |
| 3  | 1        | 3     | 1       |
| 4  | 2        | 4     | 2       |

In this case, some categories have to be treated as subcategories in certain parts of the code (the computations up to this point have been similar and independent of the hierarchy, which is why everything lives in the same table). However, they now have to be nested according to a specific set of rules defined in a separate dataframe:

| id | children |
|----|----------|
| 1  | [2,3]    |
| 2  | null     |
| 3  | null     |

This nesting depends on the item column. That is, for each row, only the entries listed as its children that share the same item value must be nested. This means that categories 2 and 3 for item 1 have to be nested under the entry with ID 1. If the output were JSON, the result should look like this:

[{
    "id": 1,
    "item": 1,
    "category": 1,
    "value": 1,
    "children": [{
        "id": 2,
        "item": 1,
        "category": 2,
        "value": 2
    }, {
        "id": 3,
        "item": 1,
        "category": 3,
        "value": 3
    }]
},

{
    "id": 4,
    "item": 2,
    "category": 1,
    "value": 4,
    "children": []
}]

Although this would be fairly simple to implement with my own code, I would like to achieve this nesting with the PySpark DataFrame API. This is what I have tried so far:

Join the two tables so that the list of children of each row is added as a column:

df_data = df_data.join(df_rules, df_data.category == df_rules.id, "left")

After the join, the result looks like this:

| id | category | value | children |
|----|----------|-------|----------|
| 1  | 1        | 1     | [2,3]    |
| 2  | 2        | 2     | []       |
| 3  | 1        | 3     | []       |
| 4  | 2        | 4     | []       |

Now, I would like to apply some kind of transformation so that I end up with something like this:

| id | category | value | item | children                |
|----|----------|-------|------|-------------------------|
| 1  | 1        | 1     | 1    |[(2,2,2,1),(3,3,3,1)]    |
| 2  | 2        | 2     | 1    |[]                       |
| 3  | 1        | 3     | 1    |[]                       |
| 4  | 2        | 4     | 1    |[]                       |

That is, the rows with IDs 2 and 3 are nested into row 1. The rest get an empty list, since there are no matches for them. After that the subcategories can be removed, but that part is easy.

This is what I am struggling to achieve. My first idea was to use something like this:

spark.sql("SELECT *, ARRAY(SELECT * FROM my_table b WHERE b.item = a.item AND b.category IN a.children) FROM my_table a")

However, it complains as soon as I add the ARRAY statement to the SELECT. I have also considered window functions or UDFs, but I am not sure how to proceed, or whether it is even possible.

1 Answer:

Answer 0 (score: 0)

I think I found a way. Here is an MCVE that creates the data with pandas. It assumes that a Spark session has already been initialized:

import pandas as pd
from pyspark.sql import functions as F
columns = ['id', 'category', 'value', 'item_id']
data = [(1,1,1,1), (2,2,2,2), (3,3,3,1), (4,4,4,2)]

spark_data = spark.createDataFrame(pd.DataFrame.from_records(data=data, columns=columns))

rules_columns = ['category', 'children_rules']
rules_data = [(1, [2, 3]), (2, []), (3, [])]

spark_rules_data = spark.createDataFrame(pd.DataFrame.from_records(data=rules_data, columns=rules_columns))

First, perform a left join with the rules that have to be applied:

joined_data = spark_data.join(spark_rules_data, on="category", how="left")
# Categories without a matching rule end up with NULL children; replace that with an empty array
joined_data = joined_data.withColumn("children_rules", F.coalesce(F.col("children_rules"), F.array()))

joined_data.createOrReplaceTempView("joined_data")


joined_data.show()


+--------+---+-----+-------+--------------+
|category| id|value|item_id|children_rules|
+--------+---+-----+-------+--------------+
|       1|  1|    1|      1|        [2, 3]|
|       3|  3|    3|      1|            []|
|       2|  2|    2|      2|            []|
|       4|  4|    4|      2|            []|
+--------+---+-----+-------+--------------+

Join the table with itself according to the rules in the children column:

nested_data = spark.sql("""SELECT joined_data_1.id as id, joined_data_1.category as category, joined_data_1.value as value, joined_data_1.item_id as item_id,
             STRUCT(joined_data_2.id as id, joined_data_2.category as category, joined_data_2.value as value, joined_data_2.item_id as item_id) as children
             FROM joined_data AS joined_data_1 LEFT JOIN joined_data AS joined_data_2 
                 ON array_contains(joined_data_1.children_rules, joined_data_2.category)""")
nested_data.createOrReplaceTempView("nested_data")
nested_data.show()

+---+--------+-----+-------+------------+
| id|category|value|item_id|    children|
+---+--------+-----+-------+------------+
|  1|       1|    1|      1|[3, 3, 3, 1]|
|  1|       1|    1|      1|[2, 2, 2, 2]|
|  3|       3|    3|      1|       [,,,]|
|  2|       2|    2|      2|       [,,,]|
|  4|       4|    4|      2|       [,,,]|
+---+--------+-----+-------+------------+

Group by the category value and aggregate the children column into a list:

grouped_data = spark.sql("SELECT category, collect_set(children) as children FROM nested_data GROUP BY category")
grouped_data.createOrReplaceTempView("grouped_data")
grouped_data.show()

+--------+--------------------+
|category|            children|
+--------+--------------------+
|       1|[[2, 2, 2, 2], [3...|
|       3|             [[,,,]]|
|       2|             [[,,,]]|
|       4|             [[,,,]]|
+--------+--------------------+
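Note that collect_set deduplicates the collected structs; if duplicate children are possible and have to be preserved, collect_list (same call signature) would be the safer choice. A minimal sketch of the same grouping with collect_list (grouped_data_list is a new name used only for illustration):

grouped_data_list = spark.sql("SELECT category, collect_list(children) as children FROM nested_data GROUP BY category")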

Join the grouped table with the original one:

original_with_children = spark_data.join(grouped_data, on="category")
original_with_children.createOrReplaceTempView("original_with_children")
original_with_children.show()

+--------+---+-----+-------+--------------------+
|category| id|value|item_id|            children|
+--------+---+-----+-------+--------------------+
|       1|  1|    1|      1|[[2, 2, 2, 2], [3...|
|       3|  3|    3|      1|             [[,,,]]|
|       2|  2|    2|      2|             [[,,,]]|
|       4|  4|    4|      2|             [[,,,]]|
+--------+---+-----+-------+--------------------+

Here comes the tricky part. We need to get rid of the entries with NULL children values (the [,,,] arrays above), replacing them with an empty array. I tried a CASE statement with an empty array, cast to array<struct<id:bigint,category:bigint,value:bigint,item_id:bigint>> (this type comes from original_with_children.dtypes):

[('category', 'bigint'),
 ('id', 'bigint'),
 ('value', 'bigint'),
 ('item_id', 'bigint'),
 ('children',
  'array<struct<id:bigint,category:bigint,value:bigint,item_id:bigint>>')]

array_type = "array<struct<id:bigint,category:bigint,value:bigint,item_id:bigint>>"
spark.sql(f"""SELECT *, CASE WHEN children[0]['category'] IS NULL THEN CAST(ARRAY() AS {array_type}) ELSE children END as no_null_children 
        FROM original_with_children""").show()

This throws the following exception (only the relevant part is shown):


Py4JJavaError                             Traceback (most recent call last)

~/miniconda3/envs/sukiyaki_venv/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:


~/miniconda3/envs/sukiyaki_venv/lib/python3.7/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:


Py4JJavaError: An error occurred while calling o28.sql.
: org.apache.spark.sql.AnalysisException: cannot resolve 'array()' due to data type mismatch: cannot cast array<string> to array<struct<id:bigint,category:bigint,value:bigint,item_id:bigint>>; line 1 pos 57;
...

I could not find a way to create an empty array of the right type (casting does not work, since there is no conversion between the default array of strings and an array of structs). Instead, this is my approach:

The order of the fields changes between invocations, which surprisingly causes type mismatches, so the type has to be queried every time:

array_type = next(value for key, value in original_with_children.dtypes if key == 'children')

# Workaround: a UDF that returns an empty array of the required struct type
empty_array_udf = F.udf(lambda : [], array_type)

aux = original_with_children.withColumn("aux", empty_array_udf())
aux.createOrReplaceTempView("aux")

There has to be a better way to create an empty column with such a complex type; a UDF introduces unnecessary overhead for something this simple.

no_null_children = spark.sql("""SELECT *, CASE WHEN children[0]['category'] IS NULL THEN aux ELSE children END as no_null_children 
          FROM aux""")
no_null_children.createOrReplaceTempView("no_null_children")
no_null_children.show()

+--------+---+-----+-------+--------------------+---+--------------------+
|category| id|value|item_id|            children|aux|    no_null_children|
+--------+---+-----+-------+--------------------+---+--------------------+
|       1|  1|    1|      1|[[2, 2, 2, 2], [3...| []|[[2, 2, 2, 2], [3...|
|       3|  3|    3|      1|             [[,,,]]| []|                  []|
|       2|  2|    2|      2|             [[,,,]]| []|                  []|
|       4|  4|    4|      2|             [[,,,]]| []|                  []|
+--------+---+-----+-------+--------------------+---+--------------------+
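As an aside, one way to avoid both the UDF and the CASE above would be to build the struct only for rows that actually matched in the self-join: collect_set (and collect_list) skip NULL inputs, so unmatched categories would end up with an empty array on their own. A rough, untested sketch of that variant, which would replace the earlier nested_data/grouped_data steps (alt_grouped is a name introduced here for illustration):

alt_grouped = spark.sql("""SELECT joined_data_1.category as category,
             collect_set(CASE WHEN joined_data_2.id IS NOT NULL
                              THEN STRUCT(joined_data_2.id as id, joined_data_2.category as category,
                                          joined_data_2.value as value, joined_data_2.item_id as item_id)
                         END) as children
             FROM joined_data AS joined_data_1 LEFT JOIN joined_data AS joined_data_2
                 ON array_contains(joined_data_1.children_rules, joined_data_2.category)
             GROUP BY joined_data_1.category""")

The walkthrough below keeps the UDF-based version, since that is what was actually run.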

Remove the unnecessary columns:

removed_columns = no_null_children.drop("aux").drop("children").withColumnRenamed("no_null_children", "children")
removed_columns.createOrReplaceTempView("removed_columns")

Remove the nested entries from the top level:

nested_categories = spark.sql("""SELECT explode(children['category']) as category FROM removed_columns""")
nested_categories.createOrReplaceTempView("nested_categories")
nested_categories.show()

+--------+
|category|
+--------+
|       2|
|       3|
+--------+


result = spark.sql("SELECT * from removed_columns WHERE category NOT IN (SELECT category FROM nested_categories)")
result.show()

+--------+---+-----+-------+--------------------+
|category| id|value|item_id|            children|
+--------+---+-----+-------+--------------------+
|       1|  1|    1|      1|[[2, 2, 2, 2], [3...|
|       4|  4|    4|      2|                  []|
+--------+---+-----+-------+--------------------+

The final JSON result is as expected:

result.toJSON().collect()


['{"category":1,"id":1,"value":1,"item_id":1,"children":[{"id":2,"category":2,"value":2,"item_id":2},{"id":3,"category":3,"value":3,"item_id":1}]}',
 '{"category":4,"id":4,"value":4,"item_id":2,"children":[]}']

Prettified:

{
   "category":1,
   "id":1,
   "value":1,
   "item_id":1,
   "children":[
      {
         "id":2,
         "category":2,
         "value":2,
         "item_id":2
      },
      {
         "id":3,
         "category":3,
         "value":3,
         "item_id":1
      }
   ]
}

{
   "category":4,
   "id":4,
   "value":4,
   "item_id":2,
   "children":[

   ]
}
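If the nested records need to be written to disk rather than collected to the driver as above, the standard DataFrame JSON writer produces the same records as JSON lines; a minimal sketch, with a purely hypothetical output path:

result.write.mode("overwrite").json("/tmp/nested_output")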