I have a dataset with the following schema:
dataset.printSchema()
|-- id: string (nullable = true)
|-- feature1: double (nullable = true)
|-- feature2: double (nullable = true)
|-- feature3: double (nullable = true)
|-- feature4: double (nullable = true)
In my application.conf, I define the subset of keys that should be transformed with reduceByKey:
keyInfo {
keysToBeTransformed = "feature1,feature2"
}
I can load these keys in my main object:
import com.typesafe.config.{Config, ConfigFactory}
val config : Config = ConfigFactory.load()
val keys : Array[String] = config.getString("keyInfo.keysToBeTransformed").split(",")
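One detail worth checking: `split(",")` keeps any spaces around the commas, so a value such as `"feature1, feature2"` yields the key `" feature2"`, which will not match the `feature2` column. The sketch below uses a small hypothetical helper, `parseKeys` (it is not part of Typesafe Config), that trims each entry and drops empty ones. It is plain Scala and assumes the raw string has already been read with `config.getString`.

```scala
object KeyParsing {
  // Hypothetical helper: splits a comma-separated key list such as
  // "feature1, feature2", trims surrounding whitespace and drops empty
  // entries so that every key matches a column name exactly.
  def parseKeys(raw: String): Array[String] =
    raw.split(",").map(_.trim).filter(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    // In the real application the string would come from
    // config.getString("keyInfo.keysToBeTransformed").
    val keys = parseKeys("feature1, feature2,")
    println(keys.mkString("[", ", ", "]")) // prints [feature1, feature2]
  }
}
```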
For each of these keys, I need to compute the mean per id in the dataset and collect the results into an array. Currently I do it like this:
val meanFeature1 : Array[Double] = dataset.map(x => (x.id, x.feature1)).rdd
.mapValues{z => (z,1)}
.reduceByKey{(x,y) => (x._1 + y._1, x._2 + y._2)}
.map( x => {
val temp = x._2
val total = temp._1
val count = temp._2
(x._1, total / count)
}).collect().sortBy(_._1).map(_._2)
val meanFeature2 : Array[Double] = dataset.map(x => (x.id, x.feature2)).rdd
.mapValues{z => (z,1)}
.reduceByKey{(x,y) => (x._1 + y._1, x._2 + y._2)}
.map( x => {
val temp = x._2
val total = temp._1
val count = temp._2
(x._1, total / count)
}).collect().sortBy(_._1).map(_._2)
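The problem with this approach is that it does not use the keys specified in my application.conf: when I change the keys there, the computation does not change with them.
How can I make the computation follow the configured keys?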
Answer 0 (score: 1)
I think the DataFrame API is better suited here, because it makes it easy to access columns dynamically by name. Converting a Dataset to a DataFrame is trivial:
val averagesPerId: Array[Array[Double]] = dataset
.groupBy("id") // untyped groupBy by column name; the aggregation below returns a DataFrame
.avg(keys: _*) // create average for each key - creates a "avg(featureX)" column for each featureX key
.sort("id")
.map(r => keys.map(col => r.getAs[Double](s"avg($col)"))) // map Rows into Array[Double], one for each ID
.collect()
// transposing the result to create an array where each row relates to a single key,
// and mapping each row to its key:
val averagesPerKey: Map[String, Array[Double]] = keys.zip(averagesPerId.transpose(identity)).toMap
// for example, if `feature1` was in `keys`:
val meanFeature1 = averagesPerKey("feature1")
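To make the shape of the data concrete, here is a plain-Scala sketch of the same two steps on an in-memory collection: average every configured key per id (sorted by id), then transpose so that each row belongs to one key. The `Record` type and the sample values are hypothetical stand-ins for the Dataset, and unlike Spark's `avg` this sketch does not skip null values.

```scala
object AveragesPerKeyDemo {
  // Hypothetical in-memory stand-in for one Dataset row.
  final case class Record(id: String, features: Map[String, Double])

  // Counterpart of groupBy("id").avg(keys: _*).sort("id") followed by the map
  // and collect: one Array[Double] per id, holding one mean per key in `keys` order.
  def averagesPerId(rows: Seq[Record], keys: Array[String]): Array[Array[Double]] =
    rows
      .groupBy(_.id)
      .toArray
      .sortBy(_._1)
      .map { case (_, group) =>
        keys.map(k => group.map(_.features(k)).sum / group.size)
      }

  // Transpose to one row per key, then attach the key names.
  def averagesPerKey(rows: Seq[Record], keys: Array[String]): Map[String, Array[Double]] =
    keys.zip(averagesPerId(rows, keys).transpose).toMap

  def main(args: Array[String]): Unit = {
    val rows = Seq(
      Record("a", Map("feature1" -> 1.0, "feature2" -> 10.0)),
      Record("a", Map("feature1" -> 3.0, "feature2" -> 20.0)),
      Record("b", Map("feature1" -> 5.0, "feature2" -> 30.0))
    )
    val byKey = averagesPerKey(rows, Array("feature1", "feature2"))
    println(byKey("feature1").mkString(", ")) // 2.0, 5.0
    println(byKey("feature2").mkString(", ")) // 15.0, 30.0
  }
}
```

Because the result is keyed by name, adding or removing a key in application.conf changes which entries exist in the map without any code change.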
Answer 1 (score: 0)
Another, similar solution I came up with in the meantime:
import org.apache.spark.sql.functions.avg
val meanFeatures : Array[Array[Double]] = keys.map(col => {dataset
.groupBy("id")
.agg(avg(col))
.as[(String,Double)]
.sort("id")
.collect().map(_._2)
})
val meanFeature1 : Array[Double] = meanFeatures(0)
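Two things are worth noting about this variant: it runs one aggregation, and therefore one `collect()` job, per key, whereas Answer 0 computes all averages in a single `groupBy`; and `meanFeatures(0)` is only `feature1` as long as `feature1` is the first key in application.conf. The plain-Scala sketch below mirrors the per-key loop on in-memory data (the `Record` type and sample values are hypothetical), using the question's (sum, count) reduction, and zips the results with `keys` so they can be looked up by name instead of by position.

```scala
object MeanPerKeyDemo {
  // Hypothetical in-memory stand-in for one Dataset row.
  final case class Record(id: String, features: Map[String, Double])

  // Per-id mean of a single key, sorted by id. Accumulates (sum, count) per id,
  // like mapValues(z => (z, 1)) followed by reduceByKey in the question.
  def meanPerId(rows: Seq[Record], key: String): Array[Double] = {
    val sums = rows.foldLeft(Map.empty[String, (Double, Int)]) { (acc, r) =>
      val (s, c) = acc.getOrElse(r.id, (0.0, 0))
      acc.updated(r.id, (s + r.features(key), c + 1))
    }
    sums.toArray.sortBy(_._1).map { case (_, (s, c)) => s / c }
  }

  def main(args: Array[String]): Unit = {
    val keys = Array("feature1", "feature2")
    val rows = Seq(
      Record("a", Map("feature1" -> 1.0, "feature2" -> 10.0)),
      Record("a", Map("feature1" -> 3.0, "feature2" -> 20.0)),
      Record("b", Map("feature1" -> 5.0, "feature2" -> 30.0))
    )
    // One pass over the data per key, as in the loop above.
    val meanFeatures: Array[Array[Double]] = keys.map(k => meanPerId(rows, k))
    // Name-based lookup that does not depend on the order of the keys.
    val byName: Map[String, Array[Double]] = keys.zip(meanFeatures).toMap
    println(byName("feature2").mkString(", ")) // 15.0, 30.0
  }
}
```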