Scala / Spark数据帧:找到与max相对应的列名

时间:2017-02-27 11:51:30

标签: scala apache-spark dataframe apache-spark-sql argmax

在Scala / Spark中,有一个数据帧:

val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")

我想计算一个新列maxCol,其中包含与最大值对应的列的名称(对于每一行)。在此示例中,输出应为:

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

实际上,数据框有超过60列。因此,需要通用的解决方案。

Python Pandas中的等价物(是的,我知道,我应该与pyspark进行比较......)可能是:

dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1) 

1 个答案:

答案 0 :(得分:10)

通过小技巧,您可以使用SELECT COUNT(*) View_Name功能。必需的进口:

greatest

首先,让我们创建一个import org.apache.spark.sql.functions.{col, greatest, lit, struct} 列表,其中第一个元素是值,第二个列名称是:

structs

这样的结构可以传递给val structs = dfIn.columns.tail.map( c => struct(col(c).as("v"), lit(c).as("k")) ) ,如下所示:

greatest
dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))

请注意,如果是关系,它将采用序列中稍后出现的元素(按字典顺序+---+---+---+---+------+ | id| c0| c1| c2|maxCol| +---+---+---+---+------+ | r0| 0| 2| 3| c2| | r1| 1| 0| 0| c0| | r2| 0| 2| 2| c2| +---+---+---+---+------+ )。如果由于某种原因这是不可接受的,您可以使用(x, "c2") > (x, "c1")明确减少:

when
import org.apache.spark.sql.functions.when

val max_col = structs.reduce(
  (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")

dfIn.withColumn("maxCol", max_col)

如果是+---+---+---+---+------+ | id| c0| c1| c2|maxCol| +---+---+---+---+------+ | r0| 0| 2| 3| c2| | r1| 1| 0| 0| c0| | r2| 0| 2| 2| c1| +---+---+---+---+------+ 列,则必须对此进行调整,例如nullable将值调整为coalescing