I have a pyspark.sql.dataframe.DataFrame in which one column contains an array of Row objects:
+------------------------------------------------------------------------------------------------+
|column |
+------------------------------------------------------------------------------------------------+
|[Row(arrival='2019-12-25 19:55', departure='2019-12-25 18:22'), |
| Row(arrival='2019-12-26 14:56', departure='2019-12-26 08:52')] |
+------------------------------------------------------------------------------------------------+
Not all rows in that column have the same number of elements (in this case there are 2, but there could be more).
What I would like to do is produce a concatenation of the times from each date, something like this:
18:22_19:55_08:52_14:56
That is: the departure time of the first element, then the arrival time of the first element, then the departure time of the second element, and then the arrival time of the second element.
Is there a simple way to do this with pyspark?
Answer 0 (score: 0)
Assuming the column name is col1 and it is an array of structs:
df.printSchema()
root
|-- col1: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- arrival: string (nullable = true)
| | |-- departure: string (nullable = true)
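For reference, a minimal sketch that rebuilds a DataFrame with this schema from the sample values in the question (the column name col1 is the assumption stated above):
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# One row whose col1 column is an array of structs with string
# arrival/departure fields, mirroring the example in the question.
df = spark.createDataFrame(
    [([Row(arrival='2019-12-25 19:55', departure='2019-12-25 18:22'),
       Row(arrival='2019-12-26 14:56', departure='2019-12-26 08:52')],)],
    ['col1']
)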
Method 1: For Spark 2.4+, use array_join + transform
from pyspark.sql.functions import expr

df.withColumn('new_list', expr("""
    array_join(
        transform(col1, x -> concat(right(x.departure,5), '_', right(x.arrival,5)))
      , '_'
    )
""")
).show(truncate=False)
+----------------------------------------------------------------------------+-----------------------+
|col1 |new_list |
+----------------------------------------------------------------------------+-----------------------+
|[[2019-12-25 19:55, 2019-12-25 18:22], [2019-12-26 14:56, 2019-12-26 08:52]]|18:22_19:55_08:52_14:56|
+----------------------------------------------------------------------------+-----------------------+
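On Spark 3.1+, transform and array_join are also exposed directly in the Python function API, so roughly the same logic can be written without expr (a sketch, assuming the same col1 column as above):
from pyspark.sql import functions as F

df.withColumn(
    'new_list',
    F.array_join(
        F.transform(
            'col1',
            # keep the last 5 characters (HH:mm) of departure, then arrival
            lambda x: F.concat(
                F.substring(x['departure'], -5, 5),
                F.lit('_'),
                F.substring(x['arrival'], -5, 5),
            ),
        ),
        '_',
    ),
).show(truncate=False)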
Method 2: Use a udf:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def arrays_join(arr):
    return '_'.join('{}_{}'.format(x.departure[-5:], x.arrival[-5:]) for x in arr) if isinstance(arr, list) else arr

udf_array_join = udf(arrays_join, StringType())

df.select(udf_array_join('col1')).show(truncate=False)
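If you want a friendlier column name than the auto-generated one, simply alias the udf result:
df.select(udf_array_join('col1').alias('new_list')).show(truncate=False)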
Method 3: Use posexplode + groupby + collect_list:
from pyspark.sql.functions import monotonically_increasing_id, posexplode, regexp_replace, expr

(df.withColumn('id', monotonically_increasing_id())
    .select('*', posexplode('col1').alias('pos', 'col2'))
    .select('id', 'pos', 'col2.*')
    .selectExpr('id', "concat(pos, '+', right(departure,5), '_', right(arrival,5)) as dt")
    .groupby('id')
    .agg(expr("concat_ws('_', sort_array(collect_list(dt))) as new_list"))
    .select(regexp_replace('new_list', r'(?:^|(?<=_))\d+\+', '').alias('new_list'))
    .show(truncate=False))
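To see why the pos prefix and sort_array are used, it can help to inspect the exploded rows before the groupby (this sketch just reuses the expressions above): collect_list does not guarantee element order, so the array position is encoded into each string and stripped again by the final regexp_replace.
from pyspark.sql.functions import monotonically_increasing_id, posexplode

# Each array element becomes its own row, tagged with the original row id
# and its position inside the array.
(df.withColumn('id', monotonically_increasing_id())
    .select('*', posexplode('col1').alias('pos', 'col2'))
    .select('id', 'pos', 'col2.*')
    .selectExpr('id', "concat(pos, '+', right(departure,5), '_', right(arrival,5)) as dt")
    .show(truncate=False))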
Method 4: Use string operations:
For this particular question only, cast the array to a string and then do a series of string operations (split + concat_ws + regexp_replace + trim) to extract the desired substrings:
from pyspark.sql.functions import regexp_replace, concat_ws, split, col

(df.select(
        regexp_replace(
            concat_ws('_', split(col('col1').astype('string'), r'[^0-9 :-]+'))
          , r'[_ ]+\d\d\d\d-\d\d-\d\d '
          , '_'
        ).alias('new_list')
    ).selectExpr('trim(both "_" from new_list) as new_list')
    .show(truncate=False))