Pyspark pandas_udf内部不正确的ArrayType元素

时间:2018-07-20 09:24:13

标签: apache-spark pyspark apache-spark-sql user-defined-functions

我正在使用Spark 2.3.0,并在我的Pyspark代码中尝试了pandas_udf用户定义的函数。根据{{​​3}},当前支持ArrayType。我的用户定义函数是:

def transform(c):
    if not any(isinstance(x, (list, tuple, np.ndarray)) for x in c.values):
        nvalues = c.values
    else:
        nvalues = np.array(c.values.tolist())
    tvalues = some_external_function(nvalues)
    if not any(isinstance(y, (list, tuple, np.ndarray)) for y in tvalues):
        p = pd.Series(np.array(tvalues))
    else:
        p = pd.Series(list(tvalues))
    return p

transform = pandas_udf(transform, ArrayType(LongType()))

当我将此功能应用于大型Spark数据框的特定数组列时,我注意到p系列的第一个元素与其他元素相比具有不同的两倍大小,而最后一个元素的大小为0:

0       [73, 10, 223, 46, 14, 73, 14, 5, 14, 21, 10, 2...
1                [223, 46, 14, 73, 14, 5, 14, 21, 30, 16]
2                 [46, 14, 73, 14, 5, 14, 21, 30, 16, 15]
...
4695                                                   []
Name: _70, Length: 4696, dtype: object

在第一个数组包含20个元素的情况下,第二个数组包含10个元素(正确的元素),最后一个为0。然后,由于数组具有多个大小,因此c.values当然以ValueError: setting an array element with a sequence.失败

当我尝试使用相同的函数对带有字符串数组的列进行排序时,则所有大小都是正确的,其余函数也将执行。

任何想法可能是什么问题?可能的错误?

已更新示例:

我将附上一个简单的示例,仅在pandas_udf函数内部打印值。

from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("testing pandas_udf")\
        .getOrCreate()

    arr = []
    for i in range(100000):
        arr.append([2,2,2,2,2])

    df = spark.createDataFrame(arr, ArrayType(LongType()))

    def transform(c):
        print(c)
        print(c.values)
        return c

    transform = pandas_udf(transform, ArrayType(LongType()))

    df = df.withColumn('new_value', transform(col('value')))
    df.show()

    spark.stop()

如果您查看执行者的日志,则会看到类似以下的日志:

0       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
1                                     [2, 2, 2, 2, 2]
2                                     [2, 2, 2, 2, 2]
3                                     [2, 2, 2, 2, 2]
4                                     [2, 2, 2, 2, 2]
5                                     [2, 2, 2, 2, 2]
...
9996                                  [2, 2, 2, 2, 2]
9997                                  [2, 2, 2, 2, 2]
9998                                               []
9999                                               []
Name: _0, Length: 10000, dtype: object

已解决:

如果遇到相同的问题,请升级到Spark 2.3.1和pyarrow 0.9.0.post1。

1 个答案:

答案 0 :(得分:0)

是的,看起来Spark中有错误。我的情况涉及2.3.0和PyArrow 0.13.0。对我而言,唯一可用的补救方法是将数组转换为字符串,然后将其传递给Pandas UDF。

def _identity(sample_array):
    return sample_array.apply(lambda e: [int(i) for i in e.split(',')])

array_identity_udf = F.pandas_udf(_identity,
                                  returnType=ArrayType(IntegerType()),
                                  functionType=F.PandasUDFType.SCALAR)
test_df = (spark
           .table('test_table')
           .repartition(F.ceil(F.rand(seed=1234) * 100))
           .cache())

test1_df = (test_df
           .withColumn('array_test', array_identity_udf(F.concat_ws(',', F.col('sample_array')))))