我需要一个 udf 函数来输入数据帧的数组列并对其中的两个字符串元素执行相等性检查。我的数据框有这样的架构。
ID | 日期 | 选项 |
---|---|---|
1 | 2021-01-06 | ['red', 'green'] |
2 | 2021-01-07 | ['蓝色', '蓝色'] |
3 | 2021-01-08 | ['蓝色', '黄色'] |
4 | 2021-01-09 | nan |
我已经试过了:
def equality_check(options: list):
try:
if options[0] == options[1]:
return 1
else:
return 0
except:
return -1
equality_udf = f.udf(equality_check, t.IntegerType())
但它抛出了索引错误。我相信 options 列是字符串数组。 期望是这样的:
ID | 日期 | 选项 | equality_check |
---|---|---|---|
1 | 2021-01-06 | ['red', 'green'] | 0 |
2 | 2021-01-07 | ['蓝色', '蓝色'] | 1 |
3 | 2021-01-08 | ['蓝色', '黄色'] | 0 |
4 | 2021-01-09 | nan | -1 |
答案 0 :(得分:1)
您可以检查 options
列表是否已定义或其长度是否小于 2,而不是使用 try/except。这是一个工作示例:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
data = [
(1, "2021-01-06", ['red', 'green']),
(2, "2021-01-07", ['Blue', 'Blue']),
(3, "2021-01-08", ['Blue', 'Yellow']),
(4, "2021-01-09", None),
]
df = spark.createDataFrame(data, ["ID", "date", "options"])
def equality_check(options: list):
if not options or len(options) < 2:
return -1
return int(options[0] == options[1])
equality_udf = F.udf(equality_check, IntegerType())
df1 = df.withColumn("equality_check", equality_udf(F.col("options")))
df1.show()
#+---+----------+--------------+--------------+
#| ID| date| options|equality_check|
#+---+----------+--------------+--------------+
#| 1|2021-01-06| [red, green]| 0|
#| 2|2021-01-07| [Blue, Blue]| 1|
#| 3|2021-01-08|[Blue, Yellow]| 0|
#| 4|2021-01-09| null| -1|
#+---+----------+--------------+--------------+
但是,我建议你不要使用 UDF,因为你可以只使用内置函数来做同样的事情:
df1 = df.withColumn(
"equality_check",
F.when(F.size(F.col("options")) < 2, -1)
.when(F.col("options")[0] == F.col("options")[1], 1)
.otherwise(0)
)