我正在尝试使用pandas_udf
。
我有我的Spark DataFrame,其中有一个Struct的数组列:
root
|-- values: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- x: double (nullable = true)
| | |-- y: double (nullable = true)
我想做的是在values
列上运行pandas_udf,并为每个记录(这意味着每个数组)基于定义的逻辑返回一个值。
在数据源中,我有多个记录,其中列values
包含空数组时,出现此错误:
line 89, in verify_result_length "expected %d, got %d" % (len(a[0]), len(result)))
RuntimeError: Result vector from pandas_udf was not the required length: expected 2, got 1
在Google上搜索时,我在这里找到了源代码:https://github.com/apache/spark/blob/master/python/pyspark/worker.py
另一方面,如果我只有一个记录为空数组,则该过程将毫无问题地结束。在Spark中使用空数组过滤记录无济于事(FYI .filter("size(values) != 0")
),这会导致相同的行为。
但是我不明白为什么会出现此错误以及我在做什么错。有人可以帮忙吗?
编辑
代码:
import numpy as np
import pandas as pd
import pyspark.sql.functions as f
udf = f.pandas_udf(udf_function, returnType="float")
df.withColumn("newColumn",
f.when(f.size(f.col("values")) == 0, 0)\
.otherwise(udf(f.col("values.x"), f.col("values.y"))))
def udf_function(x_array, y_array, Window=20):
xy = []
data = pd.DataFrame({'x': x_array.tolist()[0], 'y': y_array.tolist()[0]})
x = data['x'].rolling(int(Window / 2), center=True, min_periods=1).sum()
y = data['y'].rolling(int(Window / 2), center=True, min_periods=1).sum()
for index in range(len(x)):
xy.append(np.sqrt(x[index] ** 2 + y[index] ** 2))
if not xy:
result = 0
else:
result = max(xy)
return pd.Series(result)
记录:
{
"values": []
}
{
"values": []
}
{
"values": [
{
"x": -0.638,
"y": 0.879,
},
{
"x": -0.616,
"y": 0.809,
},
{
"x": -0.936,
"y": 0.762,
}]
}
P.S。即使应该使用f.when(f.size(f.col("values")) == 0, 0).otherwise(udf() )
条件来过滤出具有空数组的记录(或者至少在不存在空数组的情况下,不应该调用udf函数),但无论如何,udf函数看起来都可以处理这些记录记录! D:这对我来说很奇怪
更多内容:
当我使用大约15个或更多输入文件时引发了异常,这与使用空数组的记录没有关系。无论如何,这是一个怪异的异常,因为在观察到的行为下,我期望像“ JavaOutOfMemory Exception”之类的东西,但是所报告的异常并不能帮助我了解真正的问题。
同时,我回来了,使用的是更稳定的RDD而不是pandas_udf。