在Scala中查找Spark数据帧的每一行的最大值

时间:2018-07-17 08:59:26

标签: scala apache-spark apache-spark-sql

我有一个名为spark-dataframe的输入df

+---------------+---+---+---+---+
|     CustomerID| P1| P2| P3| P4|
+---------------+---+---+---+---+
|         725153|  5|  6|  7|  8|
|         873008|  7|  8|  1|  2|
|         725116|  5|  6|  3|  2|
|         725110|  0|  1|  2|  5|
+---------------+---+---+---+---+

P1,P2,P3,P4中,我需要为每个CustomerID找到最大2个值。并得到等价的column name并放入df。这样我得到的dataframe应该是

+---------------+----+----+
|     CustomerID|col1|col2|
+---------------+----+----+
|         725153|  P4|  P3|
|         873008|  P2|  P1|
|         725116|  P2|  P1|
|         725110|  P4|  P3|
+---------------+----+----+

在第一行,87是最大值。每个等效的列名称分别为P4P3。因此,对于特定的CustomerID,它应包含值P4P3。可以通过使用pyspark数据帧在pandas中实现。

nlargest = 2
order = np.argsort(-df.values, axis=1)[:, :nlargest]
result = pd.DataFrame(df.columns[order],columns=['top{}'.format(i) for i in range(1, nlargest+1)],index=recommend_df.index)

但是如何在scala中实现这一目标?

1 个答案:

答案 0 :(得分:1)

您可以使用UDF获得所需的结果。在UDF中,您可以zip的所有列名及其各自的值,然后根据该值对Array进行排序,最后从中返回前两个列名。下面是相同的代码。

//get all the columns that you want
val requiredCol = df.columns.zipWithIndex.filter(_._2!=0).map(_._1) 
//define a UDF which sorts according to the value and returns top two column names
val topTwoColumns = udf((seq: Seq[Int]) =>
  seq.zip(requiredCol).
    sortBy(_._1)(Ordering[Int].reverse).
    take(2).map(_._2))

现在,您可以使用withColumn并将列值作为array传递到先前定义的UDF。

df.withColumn("col", topTwoColumns(array(requiredCol.map(col(_)):_*))).
  select($"CustomerID",
    $"col".getItem(0).as("col1"),
    $"col".getItem(1).as("col2")).show

//output
//+----------+----+----+
//|CustomerID|col1|col2|
//+----------+----+----+
//|    725153|  P4|  P3|
//|    873008|  P2|  P1|
//|    725116|  P2|  P1|
//|    725110|  P4|  P3|
//+----------+----+----+