转换UDF的输出

时间:2017-11-22 18:32:22

标签: arrays casting pyspark

在Pyspark上,我将UDF定义如下:

from pyspark.sql.functions import udf
from scipy.spatial.distance import cdist

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

udf_closest_point = udf(closest_point)

dfC1 = dfC1.withColumn("closest", udf_closest_point(dfC1.point, dfC1.points))

我的数据如下:

  • point = [0.2,0.5]或[0.1,0.6] - float数组
  • points = [[0,1],[1,0],[1,1],[0,0]] - float数组的数组
  • nearest =例如,'[0,1]' - 一个字符串(这是其中一个值 从点转换为字符串)

我应该为UDF更改什么来恢复浮点数而不是字符串?

1 个答案:

答案 0 :(得分:1)

您可以将UDF的返回类型指定为浮点数组ArrayType(FloatType())

from pyspark.sql.types import ArrayType, FloatType
udf_closest_point = udf(closest_point, ArrayType(FloatType()))