Performing arithmetic operations on multiple columns in a Spark DataFrame

Asked: 2018-06-29 14:06:44

Tags: scala apache-spark apache-spark-sql

I have an input spark-dataframe named df,

+---------------+---+---+---+-----------+
|Main_CustomerID| P1| P2| P3|Total_Count|
+---------------+---+---+---+-----------+
|         725153|  1|  0|  2|          3|
|         873008|  0|  0|  3|          3|
|         625109|  1|  1|  0|          2|
+---------------+---+---+---+-----------+

Here, Total_Count is the sum of P1, P2 and P3, and P1, P2, P3 are the product names. I need to find the frequency of each product by dividing its value by Total_Count. I need to create a new spark-dataframe named frequencyTable, like this:

+---------------+------------------+---+------------------+-----------+
|Main_CustomerID|                P1| P2|                P3|Total_Count|
+---------------+------------------+---+------------------+-----------+
|         725153|0.3333333333333333|0.0|0.6666666666666666|          3|
|         873008|               0.0|0.0|               1.0|          3|
|         625109|               0.5|0.5|               0.0|          2|
+---------------+------------------+---+------------------+-----------+

I have done this using Scala, as

val df_columns = df.columns.toSeq
var frequencyTable = df
for (index <- df_columns) {
  if (index != "Main_CustomerID" && index != "Total_Count") {
    frequencyTable = frequencyTable.withColumn(index, df.col(index) / df.col("Total_Count"))
  }
}

But I don't prefer this for loop, since my df is of larger size. What is an optimized solution?

1 answer:

Answer 0: (score: 1)

If the dataframe is

val df = Seq(
  ("725153", 1, 0, 2, 3),
  ("873008", 0, 0, 3, 3),
  ("625109", 1, 1, 0, 2)
).toDF("Main_CustomerID", "P1", "P2", "P3", "Total_Count")

+---------------+---+---+---+-----------+
|Main_CustomerID|P1 |P2 |P3 |Total_Count|
+---------------+---+---+---+-----------+
|725153         |1  |0  |2  |3          |
|873008         |0  |0  |3  |3          |
|625109         |1  |1  |0  |2          |
+---------------+---+---+---+-----------+

You can simply use foldLeft on the columns other than Main_CustomerID and Total_Count, i.e. on P1, P2 and P3:

val df_columns = (df.columns.toSet - "Main_CustomerID" - "Total_Count").toList

df_columns.foldLeft(df) { (tempdf, colName) =>
  tempdf.withColumn(colName, df.col(colName) / df.col("Total_Count"))
}.show(false)

which should give you

+---------------+------------------+---+------------------+-----------+
|Main_CustomerID|P1                |P2 |P3                |Total_Count|
+---------------+------------------+---+------------------+-----------+
|725153         |0.3333333333333333|0.0|0.6666666666666666|3          |
|873008         |0.0               |0.0|1.0               |3          |
|625109         |0.5               |0.5|0.0               |2          |
+---------------+------------------+---+------------------+-----------+
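The same foldLeft pattern can be sanity-checked outside Spark. The sketch below (an assumption for illustration, not part of the answer: rows are modelled as plain Scala Maps in a hypothetical FoldLeftSketch object) folds over the product column names and replaces each one with its value divided by Total_Count, mirroring the chained withColumn calls above:

```scala
object FoldLeftSketch {
  // A "row" is just a column-name -> value map here, standing in for a Spark Row.
  type Row = Map[String, Double]

  val df: Seq[Row] = Seq(
    Map("Main_CustomerID" -> 725153, "P1" -> 1, "P2" -> 0, "P3" -> 2, "Total_Count" -> 3),
    Map("Main_CustomerID" -> 873008, "P1" -> 0, "P2" -> 0, "P3" -> 3, "Total_Count" -> 3),
    Map("Main_CustomerID" -> 625109, "P1" -> 1, "P2" -> 1, "P3" -> 0, "Total_Count" -> 2)
  )

  val productCols = List("P1", "P2", "P3")

  // foldLeft over the column names: each step returns a new "dataframe" with
  // that column divided by Total_Count, just like withColumn in the answer.
  val frequencyTable: Seq[Row] =
    productCols.foldLeft(df) { (tempDf, colName) =>
      tempDf.map(row => row.updated(colName, row(colName) / row("Total_Count")))
    }

  def main(args: Array[String]): Unit =
    frequencyTable.foreach(println)
}
```

Each fold step leaves Main_CustomerID and Total_Count untouched, so only the product columns are rewritten.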

I hope the answer is helpful.