如何在数组类型列上的pyspark中使用标量pandas_udf

时间:2019-03-21 21:50:49

标签: python pandas pyspark

我正在尝试在spark中实现标量pandas_udf,但是在执行特定操作时遇到错误。以下是我编写的有关列和udf结构的详细信息:

dataframe schema for array type column:

list_col1: array (nullable = true)
 |    |-- element: string (containsNull = true)

from pyspark.sql import functions as F
from pyspark.sql.functions import udf, flatten, pandas_udf
from pyspark.sql.types import ArrayType, StringType, TimestampType
from pyspark.sql import Row


@pandas_udf(ArrayType(StringType()), PandasUDFType.SCALAR)
def truncate_data_udf(list_type_col, output_list_length):     
    sortedList=pd.Series(list_type_col).tolist()    
    un_list=list(OrderedDict.fromkeys(sortedList))
    trunc_size=int(output_list_length)   
    if len(un_list)>trunc_size:
        un_list=un_list[:trunc_size]
        un_list.insert(0, 'truncated')

    return pd.Series(un_list)

df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1")), lit(10)))

Expected result is truncated list having elements equal to 10.

因此,将什么格式或数据类型的输入传递给pandas_udf。如果我想将输入列数据转换为列表,那我该怎么做。在返回数据集的同时,如何以列表形式返回结果。

The result column should have schema like:
list_col1: array (nullable = true)
 |    |-- element: string (containsNull = true)

我还编写了一个如下所示的常规udf,它可以按预期工作。但是我想确定常规和pandas_udf之间在表现方面的差异。我相信pandas_udf比普通的udf快得多。

Normal udf:

def truncate_data(list_type_col, output_list_length): 
    l= list(OrderedDict.fromkeys(list_type_col))
    if l is not None and len(l) > output_list_length:
        l = l[:output_list_length]        
        l.insert(0, 'truncated')    
    return(l)

truncate_data_udf= udf(lambda row: truncate_data(row, output_list_length), ArrayType(StringType()))

df = df.withColumn("list_col", truncate_data_udf(flatten(F.col("list_col1"))))

0 个答案:

没有答案