Conditional explode in PySpark

Asked: 2020-09-24 18:55:31

Tags: apache-spark pyspark hive

My data looks like this:

+----------+-----------------------------------+---------------------------------------------------------------------+
|athl_id   |Interest                           |branch                                                               |
+----------+-----------------------------------+---------------------------------------------------------------------+
|123       |Running                            |Running,Outdoor                                                      |
|856       |Running                            |Running                                                              |
|902       |Training,Fitness                   |Fitness,Training                                                     |
|9567      |Swimming,Training,Fitness          |Swimming,Training,Fitness                                            |
|477       |All                                |Running,All,Training,Soccer,Swimming,Fitness,Outdoor,Indoor          |
|490       |Running,Indoor                     |Running,Indoor                                                       |
+----------+-----------------------------------+---------------------------------------------------------------------+

Now I want to explode the two fields Interest and branch under the following conditions:

  1. For each athl_id, fully explode the Interest field.
  2. If any comma-separated value of branch equals any comma-separated value of Interest, drop that value from branch and explode the rest.
  3. If a comma-separated value of branch does not equal any comma-separated value of Interest, explode that value as part of branch.

Ex - In the table above, athl_id 902 has Interest Training,Fitness, and since branch holds the same values (Fitness,Training), the expected branch is null and Interest is exploded into two rows. Similarly, athl_id 477 has Interest All, and branch is Running,All,Training,Soccer,Swimming,Fitness,Outdoor,Indoor; since All is part of Interest, the exploded branch field does not contain All but does contain the remaining Running,Training,Soccer,Swimming,Fitness,Outdoor,Indoor.

Expected result:


+----------+-----------------------------------+---------------------------------------------------------------------+
|athl_id   |Interest                           |branch                                                               |
+----------+-----------------------------------+---------------------------------------------------------------------+
|123       |Running                            |Outdoor                                                              |
|856       |Running                            |                                                                     |
|902       |Training                           |                                                                     |
|902       |Fitness                            |                                                                     |
|9567      |Swimming                           |                                                                     |
|9567      |Training                           |                                                                     |
|9567      |Fitness                            |                                                                     |
|477       |All                                |Running                                                              |
|477       |All                                |Training                                                             |
|477       |All                                |Soccer                                                               |
|477       |All                                |Swimming                                                             |
|477       |All                                |Fitness                                                              |
|477       |All                                |Outdoor                                                              |
|477       |All                                |Indoor                                                               |
|490       |Running                            |                                                                     |
|490       |Indoor                             |                                                                     |
+----------+-----------------------------------+---------------------------------------------------------------------+

I tried the following, but I am getting an error. Also, I don't think array_contains is matching the exact values.

spark.sql("""  
select athl_id, Interest,
case when array_contains(split(branch,','),Interest) then null
else explode(split(branch,',')) end as branch
from (
select athl_id, explode(split(Interest,',')) as   Interest ,branch from athl_details)a
""").show(100,False )


Traceback (most recent call last):
  File "<stdin>", line 7, in <module>
  File "/usr/lib/spark/python/pyspark/sql/session.py", line 767, in sql
    return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/usr/lib/spark/python/pyspark/sql/utils.py", line 69, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
pyspark.sql.utils.AnalysisException: u"Generators are not supported when it's nested in expressions

Can someone suggest the right approach?

Thanks!

1 Answer:

Answer 0 (score: 2)

Use the array_except function, available from Spark 2.4. The original query fails because explode is a generator and, as the traceback says, generators cannot be nested inside other expressions such as CASE WHEN; computing the difference up front avoids the conditional entirely.

After splitting both columns, take the element difference of the two arrays with array_except and use explode_outer on the result.

from pyspark.sql.functions import col, explode_outer, array_except, split

# Split the comma-separated strings into arrays
split_col_df = df.withColumn('interest_array', split(col('interest'), ',')) \
                 .withColumn('branch_array', split(col('branch'), ','))

# Get the elements in branch that are not in interest
tmp_df = split_col_df.withColumn('elem_diff',
                                 array_except(col('branch_array'), col('interest_array')))

# explode_outer keeps the row (with a null) when the array is empty
res = tmp_df.withColumn('interest_expl', explode_outer(col('interest_array'))) \
            .withColumn('branch_expl', explode_outer(col('elem_diff')))

res.select('athl_id', 'interest_expl', 'branch_expl').show()
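
Since the question attempted this with spark.sql, the same logic can also be expressed in SQL. This is a sketch, assuming the data is registered as the athl_details table from the question:

spark.sql("""
SELECT athl_id, i AS Interest, b AS branch
FROM (
  SELECT athl_id,
         split(Interest, ',') AS interest_array,
         array_except(split(branch, ','), split(Interest, ',')) AS elem_diff
  FROM athl_details
) t
LATERAL VIEW explode(interest_array) iv AS i
LATERAL VIEW OUTER explode(elem_diff) bv AS b
""").show(100, False)

LATERAL VIEW OUTER explode keeps the row with a NULL b when elem_diff is empty, mirroring explode_outer in the DataFrame API.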

If the branch column can contain duplicates and you only want to subtract each common value the same number of times it appears in Interest, you may have to write a UDF to solve the problem, since array_except also deduplicates its result.
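
A minimal sketch of such a UDF, reusing the column names from the snippet above; the name multiset_except is illustrative, not part of the original answer:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType

@udf(returnType=ArrayType(StringType()))
def multiset_except(branch_arr, interest_arr):
    # Remove each interest value from branch at most once,
    # keeping unmatched duplicates (unlike array_except,
    # which deduplicates the result).
    remaining = list(branch_arr or [])
    for v in (interest_arr or []):
        if v in remaining:
            remaining.remove(v)
    return remaining

# Hypothetical usage in place of array_except:
# tmp_df = split_col_df.withColumn('elem_diff',
#                                  multiset_except(col('branch_array'), col('interest_array')))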