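I have a PySpark DataFrame in which several columns contain arrays of different lengths. I want to iterate over the relevant columns and clip the array in each row so that they all have the same length. In this example, the length is 3.
Here is a sample DataFrame: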
id_1|id_2|id_3|timestamp             |thing1           |thing2       |thing3
A   |b   |c   |[time_0,time_1,time_2]|[1.2,1.1,2.2]    |[1.3,1.5,2.6]|[2.5,3.4,2.9]
A   |b   |d   |[time_0,time_1]       |[5.1,6.1,1.4,1.6]|[5.5,6.2,0.2]|[5.7,6.3]
A   |b   |e   |[time_0,time_1]       |[0.1,0.2,1.1]    |[0.5,0.3,0.3]|[0.9,0.6,0.9,0.4]
So far, I have:
def clip_func(x, ts_len, backfill=1500):
    template = [backfill]*ts_len
    template[-len(x):] = x
    x = template
    return x[-1 * ts_len:]
clip = udf(clip_func, ArrayType(DoubleType()))
for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, 3, missing_fill))
But this does not work. If an array is too short, I want to pad it with the missing_fill value.
Answer 0 (score: 1)
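Your error is caused by passing 3 and missing_fill to clip as Python literals. As described in this answer, the inputs to a udf are converted to columns.
You should pass column literals (created with lit) instead.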
Here is a simplified example of your DataFrame:
example.show(truncate=False)
#+---+------------------------+--------------------+---------------+--------------------+
#|id |timestamp |thing1 |thing2 |thing3 |
#+---+------------------------+--------------------+---------------+--------------------+
#|A |[time_0, time_1, time_2]|[1.2, 1.1, 2.2] |[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9] |
#|B |[time_0, time_1] |[5.1, 6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[5.7, 6.3] |
#|C |[time_0, time_1] |[0.1, 0.2, 1.1] |[0.5, 0.3, 0.3]|[0.9, 0.6, 0.9, 0.4]|
#+---+------------------------+--------------------+---------------+--------------------+
You only need to make a small change to the arguments passed to the udf:
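from pyspark.sql.functions import lit, udf
from pyspark.sql.types import ArrayType, DoubleType

def clip_func(x, ts_len, backfill):
    # Start from ts_len copies of the fill value, then write x over the tail
    # (or over the whole list when x is longer than ts_len).
    template = [backfill] * ts_len
    template[-len(x):] = x
    # Keep only the last ts_len elements.
    return template[-ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

ans = example
missing_fill = 3.3
for c in [x for x in example.columns if 'thing' in x]:
    # Wrap the constants in lit() so they are passed as column literals.
    ans = ans.withColumn(c, clip(c, lit(3), lit(missing_fill)))

ans.show(truncate=False)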
#+---+------------------------+---------------+---------------+---------------+
#|id |timestamp |thing1 |thing2 |thing3 |
#+---+------------------------+---------------+---------------+---------------+
#|A |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]|[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]|
#|B |[time_0, time_1] |[6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[3.3, 5.7, 6.3]|
#|C |[time_0, time_1] |[0.1, 0.2, 1.1]|[0.5, 0.3, 0.3]|[0.6, 0.9, 0.4]|
#+---+------------------------+---------------+---------------+---------------+
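If you would rather not pass the constants through lit, one possible alternative is to bind them in plain Python before the function is wrapped as a udf. The sketch below is only an illustration: make_clip and clip_3 are hypothetical names that are not part of the question, and the Spark wiring is left in comments because it needs a running session.

```python
def make_clip(ts_len, backfill):
    """Return a one-argument clip function with ts_len and backfill fixed."""
    def clip_one(x):
        # Same logic as clip_func, with the constants captured by the closure.
        template = [backfill] * ts_len
        template[-len(x):] = x
        return template[-ts_len:]
    return clip_one

# Hypothetical helper with the values from the question bound in.
clip_3 = make_clip(3, 3.3)

print(clip_3([5.7, 6.3]))            # padded at the start
print(clip_3([5.1, 6.1, 1.4, 1.6]))  # truncated from the left

# With Spark available, the closure could be wrapped and applied like this
# (assumption: same imports and DataFrame as in the answer above):
#   clip = udf(make_clip(3, 3.3), ArrayType(DoubleType()))
#   ans = ans.withColumn(c, clip(c))
```

Because the udf then takes only the array column, no Python literal ever reaches the call site.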
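As your udf is currently written:
When the array is longer than ts_len, it truncates the array from the beginning (left side).
When the array is shorter than ts_len, it prepends missing_fill at the start of the array.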
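The behaviour described above can be checked in plain Python, without Spark. The snippet below repeats clip_func from the answer and also sketches a mirrored variant, in case you would prefer to keep the first elements and pad at the end. The name clip_keep_left is hypothetical and only meant as an illustration. One edge case that may be worth handling: with an empty input list, clip_func returns an empty list rather than a padded one, because template[-0:] selects the whole template.

```python
def clip_func(x, ts_len, backfill):
    # Logic from the answer: keep the last ts_len items, pad at the start.
    template = [backfill] * ts_len
    template[-len(x):] = x
    return template[-ts_len:]

def clip_keep_left(x, ts_len, backfill):
    # Hypothetical mirrored variant: keep the first ts_len items, pad at the end.
    head = list(x)[:ts_len]
    return head + [backfill] * (ts_len - len(head))

print(clip_func([5.1, 6.1, 1.4, 1.6], 3, 3.3))       # [6.1, 1.4, 1.6]
print(clip_func([5.7, 6.3], 3, 3.3))                 # [3.3, 5.7, 6.3]
print(clip_func([], 3, 3.3))                         # [] (not padded)

print(clip_keep_left([5.1, 6.1, 1.4, 1.6], 3, 3.3))  # [5.1, 6.1, 1.4]
print(clip_keep_left([5.7, 6.3], 3, 3.3))            # [5.7, 6.3, 3.3]
print(clip_keep_left([], 3, 3.3))                    # [3.3, 3.3, 3.3]
```

Either function can be wrapped with udf in the same way as in the answer.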