In PySpark, I defined a UDF as follows:
from pyspark.sql.functions import udf
from scipy.spatial.distance import cdist
def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]
udf_closest_point = udf(closest_point)
dfC1 = dfC1.withColumn("closest", udf_closest_point(dfC1.point, dfC1.points))
My data looks like this:
What should I change in the UDF to get floats back instead of strings?
Answer 0 (score: 1)
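You can specify the UDF's return type as an array of floats, ArrayType(FloatType()). When no return type is given, udf defaults to StringType, which is why the values come back as strings: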
from pyspark.sql.types import ArrayType, FloatType
udf_closest_point = udf(closest_point, ArrayType(FloatType()))
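To see that the strings come from the UDF's declared return type rather than from the function itself, it can help to call the function as plain Python outside Spark. The following is only a minimal sketch under stated assumptions: `math.dist` stands in for scipy's `cdist` so that it runs without SciPy or a Spark session, and `sample_point` and `sample_points` are made-up values, not the asker's data.

```python
import math


def closest_point(point, points):
    """Return the entry of `points` nearest to `point` (Euclidean distance)."""
    # math.dist is a stand-in for scipy's cdist here (an assumption made so
    # the sketch runs without SciPy or a Spark session).
    return min(points, key=lambda candidate: math.dist(point, candidate))


# Hypothetical sample row: one query point and a list of candidate points.
sample_point = [0.0, 0.0]
sample_points = [[1.0, 1.0], [0.5, 0.2], [3.0, 4.0]]

result = closest_point(sample_point, sample_points)

# Show the selected point and the Python type of each coordinate.
print(result, [type(x).__name__ for x in result])
```

The function already returns a list of Python floats, so only the declared return type needs to change. Note that Python floats are 64-bit while FloatType is 32-bit, so ArrayType(DoubleType()) may be the better choice if full precision matters.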