How do I find the index of the maximum value in a vector column?

Date: 2017-11-29 19:30:00

Tags: scala apache-spark dataframe apache-spark-sql

I have a Spark DataFrame with the following structure:

root
|-- distribution: vector (nullable = true)

+--------------------+
|   topicDistribution|
+--------------------+
|     [0.1, 0.2]     |
|     [0.3, 0.2]     |
|     [0.5, 0.2]     |
|     [0.1, 0.7]     |
|     [0.1, 0.8]     |
|     [0.1, 0.9]     |
+--------------------+

My question is: how can I add a column containing the index of the maximum value in each row's vector?

It should look like this:

root
|-- distribution: vector (nullable = true)
|-- max_index: integer (nullable = true)

+--------------------+-----------+
|   topicDistribution| max_index |
+--------------------+-----------+
|     [0.1, 0.2]     |   1       | 
|     [0.3, 0.2]     |   0       | 
|     [0.5, 0.2]     |   0       | 
|     [0.1, 0.7]     |   1       | 
|     [0.1, 0.8]     |   1       | 
|     [0.1, 0.9]     |   1       | 
+--------------------+-----------+

Thanks a lot.

I tried the following approach, but it throws an error:

import org.apache.spark.sql.functions.udf

val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )

df.withColumn("max_idx",func(($"topicDistribution"))).show()

The error says:

Exception in thread "main" org.apache.spark.sql.AnalysisException: 
cannot resolve 'UDF(topicDistribution)' due to data type mismatch: 
argument 1 requires array<double> type, however, '`topicDistribution`' 
is of vector type.;;
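
(For context: the column's Catalyst type is Spark's VectorUDT, not array<double>, so a UDF whose parameter is a Scala collection of doubles cannot be bound to it. A quick way to confirm this, assuming the DataFrame is named df as in the snippet above, is to inspect the schema directly; this is just an illustrative check, not part of the original question:)

// The column is stored as a Spark vector UDT, not as array<double>:
df.schema("topicDistribution").dataType
// returns an instance of VectorUDT rather than ArrayType(DoubleType)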

2 Answers:

Answer 0 (score: 2):

// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors,Vector}
case class myrow(topics:Vector)

val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1,0.2)),myrow(Vectors.dense(0.6,0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
|    topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+

// build the udf
import org.apache.spark.sql.functions.udf
val func = udf( (x:Vector) => x.toDense.values.toSeq.indices.maxBy(x.toDense.values) )


mydf.withColumn("max_idx",func($"topics")).show()
+----------+-------+
|    topics|max_idx|
+----------+-------+
|[0.1, 0.2]|      1|
|[0.6, 0.2]|      0|
+----------+-------+

// Note: for your specific use case you may have to change the UDF to take a Vector rather than a Seq.
// Edited to use Vector instead of Seq, as the original question and your comment asked for.
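
As a usage sketch against the column name from the question (not part of the original answer): assuming the same org.apache.spark.mllib.linalg.Vector type as above and Spark 1.5+, where Vector exposes an argmax method (on Spark 2.x ML pipelines you would import org.apache.spark.ml.linalg.Vector instead), the UDF body can be shortened and applied directly to topicDistribution:

// Hedged variant: relies on Vector.argmax (available since Spark 1.5),
// which returns the index of the first maximal element.
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.functions.udf

val maxIndexUdf = udf((v: Vector) => v.argmax)

// Applied to the asker's DataFrame `df` from the question:
df.withColumn("max_index", maxIndexUdf($"topicDistribution")).show()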

Answer 1 (score: 1):

Note: this solution may not be the best performance-wise, but it shows another way to attack the problem (and demonstrates how rich Spark SQL's Dataset API is).

The vector type is Spark MLlib's VectorUDT, so let me first create a sample dataset.

val input = Seq((0.1, 0.2), (0.3, 0.2)).toDF
import org.apache.spark.ml.feature.VectorAssembler
val vecAssembler = new VectorAssembler()
  .setInputCols(Array("_1", "_2"))
  .setOutputCol("distribution")
val ds = vecAssembler.transform(input).select("distribution")
scala> ds.printSchema
root
 |-- distribution: vector (nullable = true)

The schema looks exactly like yours.

Let's change the type from VectorUDT to a regular Array[Double]:

import org.apache.spark.ml.linalg.Vector
val arrays = ds
  .map { r => r.getAs[Vector](0).toArray }
  .withColumnRenamed("value", "distribution")
scala> arrays.printSchema
root
 |-- distribution: array (nullable = true)
 |    |-- element: double (containsNull = false)

With arrays, you can use posexplode to pair each element with its position, groupBy the distribution and take the max value, and then join and filter to recover the position of that maximum.

import org.apache.spark.sql.functions.{max, posexplode}
val pos = arrays.select($"*", posexplode($"distribution"))
val max_cols = pos
  .groupBy("distribution")
  .agg(max("col") as "max_col")
val solution = pos
  .join(max_cols, "distribution")
  .filter($"col" === $"max_col")
  .select("distribution", "pos")
scala> solution.show
+------------+---+
|distribution|pos|
+------------+---+
|  [0.1, 0.2]|  1|
|  [0.3, 0.2]|  0|
+------------+---+