I am fairly new to Scala and Spark programming.
I have a use case where I need to group the data on certain columns and get a count of another column (using pivot), and then finally I need to create a nested dataframe out of my flat dataframe.
One major challenge I am facing is that I also need to retain some other columns (not the ones I am pivoting on).
I am unable to find an efficient way to do it.
Input:
ID ID2 ID3 country items_purchased quantity
1 1 1 UK apple 1
1 1 1 USA mango 1
1 2 3 China banana 3
2 1 1 UK mango 1
Now say I want to pivot on "country" and group by ("ID", "ID2", "ID3"), but I also want to keep the other columns as lists.
For example:
OUTPUT-1:
ID ID2 ID3 UK USA China items_purchased quantity
1 1 1 1 1 0 [apple,mango] [1,1]
1 2 3 0 0 1 [banana] [3]
2 1 1 1 0 0 [mango] [1]
Once I achieve this, I want to nest it into a nested structure so that my schema looks like:
{
  "ID": 1,
  "ID2": 1,
  "ID3": 1,
  "countries": {
    "UK": 1,
    "USA": 1,
    "China": 0
  },
  "items_purchased": ["apple", "mango"],
  "quantity": [1, 1]
}
I believe I can use a case class and then map each row of the dataframe to it. However, I am not sure whether that is efficient. I would love to know if there is a more optimized way to achieve this.
What I have in mind is something along these lines:
dataframe.map(row => myCaseClass(row.getAs[Long]("ID"),
row.getAs[Long]("ID2"),
row.getAs[Long]("ID3"),
CountriesCaseClass(
row.getAs[String]("UK")
)
)
and so on...
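For reference, the snippet above presumes case class definitions roughly like the following (a hypothetical sketch only; the field names and types are illustrative and would need to match the actual schema):

// hypothetical definitions matching the identifiers used in the snippet above;
// the field types are assumptions and must be adjusted to the real schema
case class CountriesCaseClass(UK: Long, USA: Long, China: Long)
case class myCaseClass(ID: Long, ID2: Long, ID3: Long,
                       countries: CountriesCaseClass,
                       items_purchased: Seq[String],
                       quantity: Seq[Long])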
Answer 0 (score: 1)
I think this should work for your case. The number of partitions is computed from the formula partitions_num = data_size / 500MB.
import org.apache.spark.sql.functions.{collect_list, count, col, lit, map}
import spark.implicits._ // needed for toDF and $; assumes a SparkSession named spark (implicit in spark-shell)
val data = Seq(
(1, 1, 1, "UK", "apple", 1),
(1, 1, 1, "USA","mango", 1),
(1, 2, 3, "China", "banana", 3),
(2, 1, 1, "UK", "mango", 1))
// e.g: partitions_num = 125GB / 500MB = 250, adjust it according to the size of your data
val partitions_num = 250
val df = data.toDF("ID", "ID2", "ID3", "country", "items_purchased", "quantity")
.repartition(partitions_num, $"ID", $"ID2", $"ID3") //the partition should remain the same for all the operations
.persist()
// get the distinct countries; we need them to fill the nulls with 0 after pivoting, for the map column, and for the drop
val countries = df.select("country").distinct.collect.map{_.getString(0)}
//creates a sequence of key/value which should be the input for the map function
val countryMapping = countries.flatMap{c => Seq(lit(c), col(c))}
val pivotCountriesDF = df.select("ID", "ID2", "ID3", "country")
.groupBy("ID", "ID2", "ID3")
.pivot($"country")
.count()
.na.fill(0, countries)
.withColumn("countries", map(countryMapping:_*))//i.e map("UK", col("UK"), "China", col("China")) -> {"UK":0, "China":1}
.drop(countries:_*)
// pivotCountriesDF.rdd.getNumPartitions == 250; Spark retains the partition number since we didn't change the partitioning key
// +---+---+---+-------------------------------+
// |ID |ID2|ID3|countries |
// +---+---+---+-------------------------------+
// |1 |2 |3 |[China -> 1, USA -> 0, UK -> 0]|
// |1 |1 |1 |[China -> 0, USA -> 1, UK -> 1]|
// |2 |1 |1 |[China -> 0, USA -> 0, UK -> 1]|
// +---+---+---+-------------------------------+
val listDF = df.select("ID", "ID2", "ID3", "items_purchased", "quantity")
.groupBy("ID", "ID2", "ID3")
.agg(
collect_list("items_purchased").as("items_purchased"),
collect_list("quantity").as("quantity"))
// +---+---+---+---------------+--------+
// |ID |ID2|ID3|items_purchased|quantity|
// +---+---+---+---------------+--------+
// |1 |2 |3 |[banana] |[3] |
// |1 |1 |1 |[apple, mango] |[1, 1] |
// |2 |1 |1 |[mango] |[1] |
// +---+---+---+---------------+--------+
// listDF.rdd.getNumPartitions == 250; to validate this, change the partitioning key with .groupBy("ID", "ID2") and it will fall back to the default of 200 from the spark.sql.shuffle.partitions setting
val joinedDF = pivotCountriesDF.join(listDF, Seq("ID", "ID2", "ID3"))
// joinedDF.rdd.getNumPartitions == 250, the same partitions will be used for the join as well.
// +---+---+---+-------------------------------+---------------+--------+
// |ID |ID2|ID3|countries |items_purchased|quantity|
// +---+---+---+-------------------------------+---------------+--------+
// |1 |2 |3 |[China -> 1, USA -> 0, UK -> 0]|[banana] |[3] |
// |1 |1 |1 |[China -> 0, USA -> 1, UK -> 1]|[apple, mango] |[1, 1] |
// |2 |1 |1 |[China -> 0, USA -> 0, UK -> 1]|[mango] |[1] |
// +---+---+---+-------------------------------+---------------+--------+
joinedDF.toJSON.show(false)
// +--------------------------------------------------------------------------------------------------------------------+
// |value |
// +--------------------------------------------------------------------------------------------------------------------+
// |{"ID":1,"ID2":2,"ID3":3,"countries":{"China":1,"USA":0,"UK":0},"items_purchased":["banana"],"quantity":[3]} |
// |{"ID":1,"ID2":1,"ID3":1,"countries":{"China":0,"USA":1,"UK":1},"items_purchased":["apple","mango"],"quantity":[1,1]}|
// |{"ID":2,"ID2":1,"ID3":1,"countries":{"China":0,"USA":0,"UK":1},"items_purchased":["mango"],"quantity":[1]} |
// +--------------------------------------------------------------------------------------------------------------------+
Good luck, and let me know if you need any clarification.
Answer 1 (score: 0)
I don't see any problem with it; it is a pretty good solution. Anyway, I would create a Dataset for your final dataframe. It is easier to work with.
val ds: Dataset[myCaseClass] = dataframe.map(row => myCaseClass(row.getAs[Long]("ID"),
...
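If the flat result already matches the target schema (for instance the joinedDF built in the other answer, where countries ends up as a map column), an encoder-based conversion avoids mapping each row by hand. A minimal sketch, assuming the column names and types line up:

// hypothetical case class mirroring joinedDF's schema from the other answer
// (there countries is a Map[String, Long]); adjust the field types to your data
case class Purchases(ID: Int, ID2: Int, ID3: Int,
                     countries: Map[String, Long],
                     items_purchased: Seq[String],
                     quantity: Seq[Int])

import spark.implicits._ // provides the encoder for the case class
val typedDS = joinedDF.as[Purchases]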
EDIT: You asked for something like this.
input.groupby("ID","ID2","ID3")
.withColumn("UK", col("country").contains("UK"))
.withColumn("China", col("country").contains("China"))
.withColumnRenamed("country","USA", col("country").contains("USA"))