scala:删除所有列的列值低于中值的列

时间:2017-09-12 06:54:35

标签: scala apache-spark

在分析的数据缩减阶段,我想删除列总数低于所有列总数中值的所有列。 所以使用数据集:

compareWith

我总结列

<md-select multiple="true" [(ngModel)]="test" [compareWith]="compareWithFunc">
  <md-option *ngFor="let l of list" [value]="l">
    {{l.name}}
  </md-option>
</md-select>

compareWithFunc(a, b) {
  return a.name === b.name;
}

中位数为7所以我放弃v1

v1,v2,v3
1  3  5
3  4  3

我以为我可以在Row上使用流功能执行此操作。但这似乎不可能。

我提出的代码是有效的,但它看起来非常冗长,看起来很像Java代码(我认为这是我做错了的一个标志)。

有没有更有效的方法来执行此操作?

v1,v2,v3
4  7  8

1 个答案:

答案 0 :(得分:1)

在计算Spark中每列的总和后,我们可以在普通Scala中获得中值,然后通过列索引<仅选择大于或等于此值的列/ em>的。

让我们从定义计算中位数的函数开始,它只是this example的略微修改:

def median(seq: Seq[Long]): Long = {
  //In order if you are not sure that 'seq' is sorted
  val sortedSeq = seq.sortWith(_ < _)

  if (seq.size % 2 == 1) sortedSeq(sortedSeq.size / 2)
  else {
    val (up, down) = sortedSeq.splitAt(seq.size / 2)
    (up.last + down.head) / 2
  }
}

我们首先计算所有列的总和并将其转换为Seq[Long]

import org.apache.spark.sql.functions._ 
val sums = df.select(df.columns.map(c => sum(col(c)).alias(c)): _*)
             .first.toSeq.asInstanceOf[Seq[Long]]

然后我们计算median

val med = median(sums)

并将其用作生成列索引的阈值以保持:

val cols_keep = sums.zipWithIndex.filter(_._1 >= med).map(_._2)

最后,我们将这些索引映射到select()语句中:

df.select(cols_keep map df.columns map col: _*).show()
+---+---+
| v2| v3|
+---+---+
|  3|  5|
|  4|  3|
+---+---+