Access the elements of an array of Row objects and concatenate them - PySpark

Asked: 2019-12-17 12:50:17

Tags: python pyspark

I have a pyspark.sql.dataframe.DataFrame in which one column contains an array of Row objects:

    +------------------------------------------------------------------------------------------------+
    |column                                                                          |
    +------------------------------------------------------------------------------------------------+
    |[Row(arrival='2019-12-25 19:55', departure='2019-12-25 18:22'),                                 |
    |  Row(arrival='2019-12-26 14:56', departure='2019-12-26 08:52')]                                |
    +------------------------------------------------------------------------------------------------+

Not all rows in that column have the same number of elements (in this case there are 2, but there could be more).

What I want to do is build a concatenation of the time-of-day part of each of those dates, something like this:

18:22_19:55_08:52_14:56

That is, the departure time of the first element, followed by the arrival time of the first element, then the departure time of the second element, and finally the arrival time of the second element.

Is there a simple way to do this with pyspark?

1 answer:

Answer 0 (score: 0)

Assuming the column name is col1 and it is an array of structs:

df.printSchema()                                                                                                    
root
 |-- col1: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- arrival: string (nullable = true)
 |    |    |-- departure: string (nullable = true)
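
For reference, here is a minimal sketch that builds a DataFrame matching this schema, so the snippets below can be run end to end (the column name col1 and the field names simply follow the assumption above):

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# a single row whose col1 holds an array of two structs, mirroring the sample data
df = spark.createDataFrame([
    ([Row(arrival='2019-12-25 19:55', departure='2019-12-25 18:22'),
      Row(arrival='2019-12-26 14:56', departure='2019-12-26 08:52')],)
], ['col1'])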

Method 1: For Spark 2.4+, use array_join + transform

from pyspark.sql.functions import expr

# transform() maps each struct to 'departure_arrival' (time part only, via right(...,5)),
# and array_join() then joins the per-element strings with '_'
df.withColumn('new_list', expr("""
    array_join(
        transform(col1, x -> concat(right(x.departure,5), '_', right(x.arrival,5)))
      , '_'
    )
  """)
).show(truncate=False)

+----------------------------------------------------------------------------+-----------------------+
|col1                                                                        |new_list               |
+----------------------------------------------------------------------------+-----------------------+
|[[2019-12-25 19:55, 2019-12-25 18:22], [2019-12-26 14:56, 2019-12-26 08:52]]|18:22_19:55_08:52_14:56|
+----------------------------------------------------------------------------+-----------------------+

Method 2: Use a udf:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# join 'departure_arrival' (time-of-day part only) for every struct in the array;
# non-list values (e.g. null) are passed through unchanged
def arrays_join(arr):
    return '_'.join('{}_{}'.format(x.departure[-5:], x.arrival[-5:]) for x in arr) if isinstance(arr, list) else arr

udf_array_join = udf(arrays_join, StringType())

df.select(udf_array_join('col1')).show(truncate=False)
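
The udf approach works on any Spark version, but it pays the usual Python serialization overhead per row, unlike the built-in functions of Method 1. If you want a friendlier column name in the result, just add an alias (this yields the same new_list string as Method 1):

df.select(udf_array_join('col1').alias('new_list')).show(truncate=False)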

Method 3: Use posexplode + groupby + collect_list:

from pyspark.sql.functions import monotonically_increasing_id, posexplode, regexp_replace, expr

(df.withColumn('id', monotonically_increasing_id())        # tag each original row
    .select('*', posexplode('col1').alias('pos', 'col2'))  # one row per array element, keeping its position
    .select('id', 'pos', 'col2.*')
    .selectExpr('id', "concat(pos, '+', right(departure,5), '_', right(arrival,5)) as dt")
    .groupby('id')
    # collect_list does not guarantee order, so sort on the leading 'pos+' prefix to restore it
    .agg(expr("concat_ws('_', sort_array(collect_list(dt))) as new_list"))
    # strip the 'pos+' prefixes once the ordering has been restored
    .select(regexp_replace('new_list', r'(?:^|(?<=_))\d+\+', '').alias('new_list'))
    .show(truncate=False))

Method 4: Use string operations:

For this specific problem only, cast the array to its string representation and then apply a chain of string operations (split + concat_ws + regexp_replace + trim) to extract the desired substrings:

from pyspark.sql.functions import regexp_replace, concat_ws, split, col

(df.select(
    regexp_replace(
        # cast the array to its string form, split on anything that is not part of a timestamp
        # (brackets, commas), then rejoin the pieces with '_'
        concat_ws('_', split(col('col1').astype('string'), r'[^0-9 :-]+'))
        # drop the date part of every timestamp, keeping only the trailing HH:mm
      , r'[_ ]+\d\d\d\d-\d\d-\d\d '
      , '_'
    ).alias('new_list')
).selectExpr('trim(both "_" from new_list) as new_list')  # remove leading/trailing '_'
.show(truncate=False))