我正在使用PySpark UDF在Spark worker上执行代码。如果UDF中引发了异常,则将其包装在Py4JJavaError
中并在Python中重新引发。为了正确处理错误,我需要原始错误。有没有办法从Py4JJavaError
获得它?
原始错误的字符串表示形式被打印为堆栈跟踪的一部分,因此可以通过解析跟踪至少获得错误的类型。但是,这将很乏味且容易出错。
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(pd.DataFrame({"A": [1, 2, 3]}))
@udf
def test(x):
raise ValueError(f"Got {x}")
df = df.withColumn("B", test("A"))
df.show()
我希望我可以在不解析堆栈跟踪的情况下提取最初引发的错误,或者至少提取错误的名称和/或错误消息。