我创建了一个pyspark数据框,如下所示:
df = spark.createDataFrame([([0.1,0.2], 2), ([0.1], 3), ([0.3,0.3,0.4], 2)], ("a", "b"))
df.show()
+---------------+---+
| a| b|
+---------------+---+
| [0.1, 0.2]| 2|
| [0.1]| 3|
|[0.3, 0.3, 0.4]| 2|
+---------------+---+
现在,我正尝试一次解析“ a”列,如下所示:
parse_col = udf(lambda row: [ x for x in row.a], ArrayType(FloatType()))
new_df = df.withColumn("a_new", parse_col(struct([df[x] for x in df.columns if x == 'a'])))
new_df.show()
这很好。
+---------------+---+---------------+
| a| b| a_new|
+---------------+---+---------------+
| [0.1, 0.2]| 2| [0.1, 0.2]|
| [0.1]| 3| [0.1]|
|[0.3, 0.3, 0.4]| 2|[0.3, 0.3, 0.4]|
+---------------+---+---------------+
但是当我尝试格式化值时,如下所示:
count_empty_columns = udf(lambda row: ["{:.2f}".format(x) for x in row.a], ArrayType(FloatType()))
new_df = df.withColumn("a_new", count_empty_columns(struct([df[x] for x in df.columns if x == 'a'])))
new_df.show()
它不起作用-值丢失了
+---------------+---+-----+
| a| b|a_new|
+---------------+---+-----+
| [0.1, 0.2]| 2| [,]|
| [0.1]| 3| []|
|[0.3, 0.3, 0.4]| 2| [,,]|
+---------------+---+-----+
我正在使用Spark v2.3.1
你知道我在做什么错吗?
谢谢
答案 0 :(得分:1)
很简单-类型很重要。您将输出声明为array<string>
,但格式化的字符串不是一个。因此结果是不确定的。换句话说,字符串和浮点数是互斥的。
如果需要字符串,则应这样声明列
udf(lambda row: ["{:.2f}".format(x) for x in row.a], "array<string>")
否则,您应考虑四舍五入或使用固定精度的数字。
df.select(df["a"].cast("array<decimal(38, 2)>")).show()
+------------------+
| a|
+------------------+
| [0.10, 0.20]|
| [0.10]|
|[0.30, 0.30, 0.40]|
+------------------+
尽管这些是完全不同的操作。