我正在尝试在spark中实现标量pandas_udf,但是在执行特定操作时遇到错误。以下是我编写的有关列和udf结构的详细信息:
dataframe schema for array type column:
list_col1: array (nullable = true)
| |-- element: string (containsNull = true)
from pyspark.sql import functions as F
from pyspark.sql.functions import udf, flatten, pandas_udf
from pyspark.sql.types import ArrayType, StringType, TimestampType
from pyspark.sql import Row
@pandas_udf(ArrayType(StringType()), PandasUDFType.SCALAR)
def truncate_data_udf(list_type_col, output_list_length):
sortedList=pd.Series(list_type_col).tolist()
un_list=list(OrderedDict.fromkeys(sortedList))
trunc_size=int(output_list_length)
if len(un_list)>trunc_size:
un_list=un_list[:trunc_size]
un_list.insert(0, 'truncated')
return pd.Series(un_list)
df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1")), lit(10)))
Expected result is truncated list having elements equal to 10.
因此,将什么格式或数据类型的输入传递给pandas_udf。如果我想将输入列数据转换为列表,那我该怎么做。在返回数据集的同时,如何以列表形式返回结果。
The result column should have schema like:
list_col1: array (nullable = true)
| |-- element: string (containsNull = true)
我还编写了一个如下所示的常规udf,它可以按预期工作。但是我想确定常规和pandas_udf之间在表现方面的差异。我相信pandas_udf比普通的udf快得多。
Normal udf:
def truncate_data(list_type_col, output_list_length):
l= list(OrderedDict.fromkeys(list_type_col))
if l is not None and len(l) > output_list_length:
l = l[:output_list_length]
l.insert(0, 'truncated')
return(l)
truncate_data_udf= udf(lambda row: truncate_data(row, output_list_length), ArrayType(StringType()))
df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1"))))