Pandas_udf问题:将函数应用于数据为ArrayType的每一行

时间:2019-04-08 14:10:40

标签: python pandas pyspark

我有一个pyspark数据框,其中多列包含不同长度的数组。我想遍历相关的列并剪切每一行中的数组,以使它们的长度相同。在此示例中,长度为3。

这是一个示例数据框:

id_1|id_2|id_3|        timestamp     |thing1       |thing2       |thing3
A   |b  |  c |[time_0,time_1,time_2]|[1.2,1.1,2.2]|[1.3,1.5,2.6|[2.5,3.4,2.9]
A   |b  |  d |[time_0,time_1]       |[5.1,6.1, 1.4, 1.6]    |[5.5,6.2, 0.2]   |[5.7,6.3]
A   |b  |  e |[time_0,time_1]       |[0.1,0.2, 1.1]    |[0.5,0.3, 0.3]   |[0.9,0.6, 0.9, 0.4]

到目前为止,我有

 def clip_func(x, ts_len, backfill=1500):
     template = [backfill]*ts_len
     template[-len(x):] = x
     x = template
     return x[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, 3, missing_fill))

但是不起作用。如果数组太短,我想用missing_fill值填充数组。

1 个答案:

答案 0 :(得分:1)

您的错误是由于将3missing_fill作为python文字传递给clip引起的。如this answer中所述,将udf的输入转换为列。

您应该改为传递列文字。

这是DataFrame的简化示例:

example.show(truncate=False)
#+---+------------------------+--------------------+---------------+--------------------+
#|id |timestamp               |thing1              |thing2         |thing3              |
#+---+------------------------+--------------------+---------------+--------------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]     |[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]     |
#|B  |[time_0, time_1]        |[5.1, 6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[5.7, 6.3]          |
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]     |[0.5, 0.3, 0.3]|[0.9, 0.6, 0.9, 0.4]|
#+---+------------------------+--------------------+---------------+--------------------+

您只需要对传递给udf的参数做一个小的更改:

from pyspark.sql.functions import lit, udf

def clip_func(x, ts_len, backfill):
    template = [backfill]*ts_len
    template[-len(x):] = x
    x = template
    return x[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

ans = example
for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, lit(3), lit(missing_fill)))

ans.show(truncate=False)
#+---+------------------------+---------------+---------------+---------------+
#|id |timestamp               |thing1         |thing2         |thing3         |
#+---+------------------------+---------------+---------------+---------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]|[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]|
#|B  |[time_0, time_1]        |[6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[3.3, 5.7, 6.3]|
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]|[0.5, 0.3, 0.3]|[0.6, 0.9, 0.4]|
#+---+------------------------+---------------+---------------+---------------+

您的udf当前写为:

  • 当数组长于ts_len时,它将从开头(左侧)截断数组。
  • 当数组短于ts_len时,它将在数组的开头附加missing_fill