如何计算Spark数据集中的平均向量w。 Scala呢?

时间:2018-05-14 14:10:12

标签: scala apache-spark

假设我在Apache Spark中有一个数据集,如下所示:

+---+--------------------+
| id|                 vec|
+---+--------------------+
|  0|[1, 2, 3, 4]        |
|  0|[2, 3, 4, 5]        |
|  0|[6, 7, 8, 9]        |
|  1|[1, 2, 3, 4]        |
|  1|[5, 6, 7, 8]        |
+---+--------------------+

vec是List的{​​{1}}。

如何从中创建一个包含id的数据集以及与该id相关联的向量的平均值,如下所示:

Doubles

提前致谢!

1 个答案:

答案 0 :(得分:0)

创建一个case类来匹配DataSet的输入模式。 按ID分组数据集,并使用foldLeft累积分组数据集的向量中每个idx的平均值。

scala> case class Test(id: Int, vec: List[Double])
defined class Test

scala> val inputList = List(
     |   Test(0, List(1, 2, 3, 4)),
     |   Test(0, List(2, 3, 4, 5)),
     |   Test(0, List(6, 7, 8, 9)),
     |   Test(1, List(1, 2, 3, 4)),
     |   Test(1, List(5, 6, 7, 8)))
inputList: List[Test] = List(Test(0,List(1.0, 2.0, 3.0, 4.0)), Test(0,List(2.0, 3.0, 4.0, 5.0)), Test(0,List(6.0, 7.0, 8.0, 9.0)), Test(1,
List(1.0, 2.0, 3.0, 4.0)), Test(1,List(5.0, 6.0, 7.0, 8.0)))

scala>

scala> import spark.implicits._
import spark.implicits._

scala> val ds = inputList.toDF.as[Test]
ds: org.apache.spark.sql.Dataset[Test] = [id: int, vec: array<double>]

scala> ds.show(false)
+---+--------------------+
|id |vec                 |
+---+--------------------+
|0  |[1.0, 2.0, 3.0, 4.0]|
|0  |[2.0, 3.0, 4.0, 5.0]|
|0  |[6.0, 7.0, 8.0, 9.0]|
|1  |[1.0, 2.0, 3.0, 4.0]|
|1  |[5.0, 6.0, 7.0, 8.0]|
+---+--------------------+


scala>

scala> val outputDS = ds.groupByKey(_.id).mapGroups {
     |   case (key, valuePairs) =>
     |     val vectors = valuePairs.map(_.vec).toArray
     |     // compute the length of the vectors for each key
     |     val len = vectors.length
     |     // get average for each index in vectors
     |     val avg = vectors.head.indices.foldLeft(List[Double]()) {
     |       case (acc, idx) =>
     |         val sumOfIdx = vectors.map(_ (idx)).sum
     |         acc :+ (sumOfIdx / len)
     |     }
     |     Test(key, avg)
     | }
outputDS: org.apache.spark.sql.Dataset[Test] = [id: int, vec: array<double>]

scala> outputDS.show(false)
+---+--------------------+
|id |vec                 |
+---+--------------------+
|1  |[3.0, 4.0, 5.0, 6.0]|
|0  |[3.0, 4.0, 5.0, 6.0]|
+---+--------------------+

希望这有帮助!