我已经在spark-shell中编写了Scala代码,以将数据框的一列映射到另一列。我现在正在尝试将其转换为Java,但是在使用我定义的UDF时遇到了困难。
我正在处理此数据框:
+------+-----+-----+
|acctId|vehId|count|
+------+-----+-----+
| 1| 777| 3|
| 2| 777| 1|
| 1| 666| 1|
| 1| 999| 3|
| 1| 888| 2|
| 3| 777| 4|
| 2| 999| 1|
| 3| 888| 2|
| 2| 888| 3|
+------+-----+-----+
并将其转换为此:
+------+----------------------------------------+
|acctId|vehIdToCount |
+------+----------------------------------------+
|1 |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3 |[777 -> 4, 888 -> 2] |
|2 |[777 -> 1, 999 -> 1, 888 -> 3] |
+------+----------------------------------------+
我正在通过这些命令执行此操作。 首先,我的UDF将行值列表从一列映射到第二列:
val listToMap = udf((input: Seq[Row]) => input.map(row => (row.getAs[Long](0), row.getAs[Long](1))).toMap)
我通过双重分组/聚合来做到这一点:
val resultDF = testData.groupBy("acctId", "vehId")
.agg(count("acctId").cast("long").as("count"))
.groupBy("acctId")
.agg(collect_list(struct("vehId", "count")) as ("vehIdToCount"))
.withColumn("vehIdToCount", listToMap($"map"))
我的问题是试图用Java编写listToMap UDF。我对Scala和Java都是陌生的,所以我可能只是缺少一些东西。
我希望可以做些简单的事情:
UserDefinedFunction listToMap = udf(
(Seq<Dataset<Row>> input) -> input.map(r -> (r.get(“vehicleId”), r.get(“count”)));
);
但是,即使仔细阅读了文档,我也无法确定一种有效的方法来获取每一列。我也尝试过仅执行SELECT,但这也不起作用。
非常感谢您的帮助。供您参考,这是我如何在spark-shell中生成测试数据的方法:
val testData = Seq(
(1, 999),
(1, 999),
(2, 999),
(1, 888),
(2, 888),
(3, 888),
(2, 888),
(2, 888),
(1, 888),
(1, 777),
(1, 666),
(3, 888),
(1, 777),
(3, 777),
(2, 777),
(3, 777),
(3, 777),
(1, 999),
(3, 777),
(1, 777)
).toDF("acctId", "vehId”)
答案 0 :(得分:1)
我无法帮助您编写UDF,但是我可以向您展示如何使用Spark的内置map_from_entries
函数来避免使用UDF。 UDF应该始终是最后的求助之路,既要保持代码库简单,又要因为Spark无法对其进行优化。下面的示例在Scala中,但翻译起来很简单:
scala> val testData = Seq(
| (1, 999),
| (1, 999),
| (2, 999),
| (1, 888),
| (2, 888),
| (3, 888),
| (2, 888),
| (2, 888),
| (1, 888),
| (1, 777),
| (1, 666),
| (3, 888),
| (1, 777),
| (3, 777),
| (2, 777),
| (3, 777),
| (3, 777),
| (1, 999),
| (3, 777),
| (1, 777)
| ).toDF("acctId", "vehId")
testData: org.apache.spark.sql.DataFrame = [acctId: int, vehId: int]
scala>
scala> val withMap = testData.groupBy('acctId, 'vehId).
| count.
| select('acctId, struct('vehId, 'count).as("entries")).
| groupBy('acctId).
| agg(map_from_entries(collect_list('entries)).as("myMap"))
withMap: org.apache.spark.sql.DataFrame = [acctId: int, myMap: map<int,bigint>]
scala>
scala> withMap.show(false)
+------+----------------------------------------+
|acctId|myMap |
+------+----------------------------------------+
|1 |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3 |[777 -> 4, 888 -> 2] |
|2 |[777 -> 1, 999 -> 1, 888 -> 3] |
+------+----------------------------------------+