How to find the maximum value of multiple columns?

Asked: 2019-08-16 22:26:08

Tags: scala apache-spark apache-spark-sql

I am trying to find the maximum value across multiple columns of a Spark DataFrame. Every column holds values of type double.

The DataFrame looks like this, and the expected result is a new column holding, for each row, the maximum of A, B, C, and D:

+-----+---+----+---+---+
|Name | A | B  | C | D |
+-----+---+----+---+---+
|Alex |5.1|-6.2|  7|  8|
|John |  7| 8.3|  1|  2|
|Alice|  5|  46|  3|  2|
|Mark |-20| -11|-22| -5|
+-----+---+----+---+---+

3 Answers:

Answer 0 (score: 4):

You can apply greatest to a list of the numeric columns, as follows:

import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import spark.implicits._  // assumes a SparkSession named spark is in scope, as in spark-shell

val df = Seq(
  ("Alex", 5.1, -6.2, 7.0, 8.0),
  ("John", 7.0, 8.3, 1.0, 2.0),
  ("Alice", 5.0, 46.0, 3.0, 2.0),
  ("Mark", -20.0, -11.0, -22.0, -5.0),
).toDF("Name", "A", "B", "C", "D")

val numCols = df.columns.tail  // Apply suitable filtering as needed (*)

df.withColumn("MaxValue", greatest(numCols.head, numCols.tail: _*)).
  show
// +-----+-----+-----+-----+----+--------+
// | Name|    A|    B|    C|   D|MaxValue|
// +-----+-----+-----+-----+----+--------+
// | Alex|  5.1| -6.2|  7.0| 8.0|     8.0|
// | John|  7.0|  8.3|  1.0| 2.0|     8.3|
// |Alice|  5.0| 46.0|  3.0| 2.0|    46.0|
// | Mark|-20.0|-11.0|-22.0|-5.0|    -5.0|
// +-----+-----+-----+-----+----+--------+

(*) For example, to keep only the top-level DoubleType columns:

import org.apache.spark.sql.types._

val numCols = df.schema.fields.collect{
  case StructField(name, DoubleType, _, _) => name
}
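
If the DataFrame mixes numeric types (say, Int and Double columns), a slightly wider pattern match picks them all up. A minimal sketch, assuming every top-level numeric column should participate:

import org.apache.spark.sql.types._

// Collect the names of all top-level numeric columns (IntegerType, DoubleType, ...).
val numCols = df.schema.fields.collect {
  case StructField(name, _: NumericType, _, _) => name
}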

If you are using Spark 2.4+, array_max is an option, although in this case it involves an additional transformation step:

df.withColumn("MaxValue", array_max(array(numCols.map(col): _*)))

Answer 1 (score: -1):

First, I reproduced your df:

scala> df.show
+-----+---+----+---+---+
| Name|  A|   B|  C|  D|
+-----+---+----+---+---+
| Alex|5.1|-6.2|  7|  8|
| John|  7| 8.3|  1|  2|
|Alice|  5|  46|  3|  2|
| Mark|-20| -11|-22| -5|
+-----+---+----+---+---+

Then I converted it to an RDD and did the transformation at the row level (a typed Dataset alternative is sketched after the output below):

import scala.math.max
import spark.implicits._  // needed for .toDF on the RDD

case class MyData(Name: String, A: Double, B: Double, C: Double, D: Double, MaxValue: Double)

val maxDF = df.rdd.map { row =>
  // Parse each numeric field, then take the pairwise max of all four.
  val a = row(1).toString.toDouble
  val b = row(2).toString.toDouble
  val c = row(3).toString.toDouble
  val d = row(4).toString.toDouble
  MyData(row(0).toString, a, b, c, d, max(max(a, b), max(c, d)))
}.toDF

Here is the final output:

maxDF.show
+-----+-----+-----+-----+----+--------+
| Name|    A|    B|    C|   D|MaxValue|
+-----+-----+-----+-----+----+--------+
| Alex|  5.1| -6.2|  7.0| 8.0|     8.0|
| John|  7.0|  8.3|  1.0| 2.0|     8.3|
|Alice|  5.0| 46.0|  3.0| 2.0|    46.0|
| Mark|-20.0|-11.0|-22.0|-5.0|    -5.0|
+-----+-----+-----+-----+----+--------+
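
As the side note promised above: the same row-level computation can stay in the typed Dataset API instead of dropping to the RDD, avoiding the toString round-trips. A minimal sketch, assuming the columns are already typed as Double (cast them first otherwise) and reusing the MyData case class above; InputRow is a hypothetical helper:

// Decode each row into a typed case class, then take the max of the four fields.
case class InputRow(Name: String, A: Double, B: Double, C: Double, D: Double)

val maxDS = df.as[InputRow].map { r =>
  MyData(r.Name, r.A, r.B, r.C, r.D, Seq(r.A, r.B, r.C, r.D).max)
}
maxDS.show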

Answer 2 (score: -1):

You can define a UDF that receives an array and returns its maximum:

val getMaxColumns = udf((xs: Seq[Double]) => xs.max)

Then build an array of the columns you want the maximum over (however many there are):

val columns = array($"A",$"B",$"C",$"D")

In your example, since you want the maximum over all the columns after the first, you can do:

val columns = array(df.columns.tail.map(x => $"$x"): _*)

Then apply withColumn with the previous udf:

df.withColumn("maxValue", getMaxColumns(columns))

Remember the imports:

import org.apache.spark.sql.functions.{udf, array}
import spark.implicits._  // for the $"colName" syntax

A quick example:

Input:

df.show
+-----+-----+-----+-----+----+
| Name|    A|    B|    C|   D|
+-----+-----+-----+-----+----+
| Alex|  5.1| -6.2|  7.0| 8.0|
| John|  7.0|  8.3|  1.0| 2.0|
|Alice|  5.0| 46.0|  3.0| 2.0|
| Mark|-20.0|-11.0|-22.0|-5.0|
+-----+-----+-----+-----+----+

Output:

df.withColumn("maxValue", getMaxColumns(columns)).show
+-----+-----+-----+-----+----+--------+
| Name|    A|    B|    C|   D|maxValue|
+-----+-----+-----+-----+----+--------+
| Alex|  5.1| -6.2|  7.0| 8.0|     8.0|
| John|  7.0|  8.3|  1.0| 2.0|     8.3|
|Alice|  5.0| 46.0|  3.0| 2.0|    46.0|
| Mark|-20.0|-11.0|-22.0|-5.0|    -5.0|
+-----+-----+-----+-----+----+--------+