I have a Spark DataFrame with the following structure:
root
|-- distribution: vector (nullable = true)
+--------------------+
| topicDistribution|
+--------------------+
| [0.1, 0.2] |
| [0.3, 0.2] |
| [0.5, 0.2] |
| [0.1, 0.7] |
| [0.1, 0.8] |
| [0.1, 0.9] |
+--------------------+
My question is: how do I add a column holding the index of the maximum value in each row's vector?
It should look like this:
root
|-- distribution: vector (nullable = true)
|-- max_index: integer (nullable = true)
+--------------------+-----------+
| topicDistribution| max_index |
+--------------------+-----------+
| [0.1, 0.2] | 1 |
| [0.3, 0.2] | 0 |
| [0.5, 0.2] | 0 |
| [0.1, 0.7] | 1 |
| [0.1, 0.8] | 1 |
| [0.1, 0.9] | 1 |
+--------------------+-----------+
Many thanks.
I tried the following, but it throws an error:
import org.apache.spark.sql.functions.udf
val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )
df.withColumn("max_idx",func(($"topicDistribution"))).show()
The error says:
Exception in thread "main" org.apache.spark.sql.AnalysisException:
cannot resolve 'UDF(topicDistribution)' due to data type mismatch:
argument 1 requires array<double> type, however, '`topicDistribution`'
is of vector type.;;
Answer 0 (score: 2)
// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors, Vector}
case class myrow(topics: Vector)
val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1, 0.2)), myrow(Vectors.dense(0.6, 0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
| topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+
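As a side note, the snippet above targets the Spark 1.x API (sqlContext, RDDs). On Spark 2.x+ the same sample data can be built directly from a Seq via the SparkSession implicits; a minimal sketch, assuming a spark-shell session (MyRow and mydf2 are illustrative names):
// Spark 2.x+ equivalent; in spark-shell, `spark` and its
// implicits (which provide toDF) are already in scope
import org.apache.spark.mllib.linalg.{Vector, Vectors}
case class MyRow(topics: Vector)
val mydf2 = Seq(MyRow(Vectors.dense(0.1, 0.2)), MyRow(Vectors.dense(0.6, 0.2))).toDF
mydf2.show()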
// build the udf
import org.apache.spark.sql.functions.udf
val func = udf { (x: Vector) =>
  val values = x.toDense.values   // materialize the dense values once
  values.indices.maxBy(i => values(i))
}
mydf.withColumn("max_idx", func($"topics")).show()
+----------+-------+
| topics|max_idx|
+----------+-------+
|[0.1, 0.2]| 1|
|[0.6, 0.2]| 0|
+----------+-------+
// Note: edited to take a Vector instead of a Seq, as the original question and your comment asked.
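As a side note, assuming a Spark version where Vector.argmax is available (it exists on both the mllib and ml Vector types in recent releases), the UDF body reduces to a one-liner; argmaxUdf is an illustrative name:
import org.apache.spark.sql.functions.udf
import org.apache.spark.mllib.linalg.Vector
// argmax returns the index of the first maximal element of the vector
val argmaxUdf = udf( (x: Vector) => x.argmax )
mydf.withColumn("max_idx", argmaxUdf($"topics")).show()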
Answer 1 (score: 1)
Note: this solution may not be the best-performing one, but it shows another way to tackle the problem (and how rich Spark SQL's Dataset API is).
The vector type is Spark MLlib's VectorUDT, so let me first create a sample dataset.
val input = Seq((0.1, 0.2), (0.3, 0.2)).toDF
import org.apache.spark.ml.feature.VectorAssembler
val vecAssembler = new VectorAssembler()
.setInputCols(Array("_1", "_2"))
.setOutputCol("distribution")
val ds = vecAssembler.transform(input).select("distribution")
scala> ds.printSchema
root
|-- distribution: vector (nullable = true)
The schema looks exactly like yours.
Let's change the type from VectorUDT to a regular Array[Double].
import org.apache.spark.ml.linalg.Vector
val arrays = ds
.map { r => r.getAs[Vector](0).toArray }
.withColumnRenamed("value", "distribution")
scala> arrays.printSchema
root
|-- distribution: array (nullable = true)
| |-- element: double (containsNull = false)
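As a side note, if you are on Spark 3.0 or later, the built-in vector_to_array function does the same VectorUDT-to-array conversion without the typed map; a minimal sketch against the ds dataset from above (arrays2 is an illustrative name):
// Spark 3.0+: convert the MLlib vector column to array<double> in one call
import org.apache.spark.ml.functions.vector_to_array
val arrays2 = ds.select(vector_to_array($"distribution") as "distribution")
arrays2.printSchema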
With arrays in place, you can use posexplode to index the elements of each array, groupBy with max to find the largest value per distribution, and join the two to assemble the solution.
import org.apache.spark.sql.functions.{posexplode, max}
val pos = arrays.select($"*", posexplode($"distribution"))
val max_cols = pos
.groupBy("distribution")
.agg(max("col") as "max_col")
val solution = pos
.join(max_cols, "distribution")
.filter($"col" === $"max_col")
.select("distribution", "pos")
scala> solution.show
+------------+---+
|distribution|pos|
+------------+---+
| [0.1, 0.2]| 1|
| [0.3, 0.2]| 0|
+------------+---+
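As a side note, on Spark 2.4 or later the groupBy/join round-trip can be avoided with the built-in array_max and array_position functions (array_position is 1-based, hence the - 1); a minimal sketch against the same arrays dataset (solution2 is an illustrative name):
// Spark 2.4+: locate the maximum directly with SQL array functions
import org.apache.spark.sql.functions.expr
val solution2 = arrays.withColumn(
  "pos", expr("array_position(distribution, array_max(distribution)) - 1"))
solution2.show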