Scala groupBy with an "allowed deviation"

Time: 2017-01-10 17:55:57

Tags: scala

I have:

val a = List((1.1, 2), (1.2, 3), (3.1, 4), (2.9, 5))

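I want to group this list by an "allowed deviation"; in other words, group each Double together with the Doubles that are larger or smaller than it by no more than that deviation.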

The result I want (assuming an allowed deviation of 0.2):

Map((1.1, 1.2) -> List((1.1, 2),(1.2, 3)), (3.1, 2.9) -> List((3.1, 4), (2.9, 5)))

How can I do this?

2 answers:

Answer 0 (score: 1)

Not sure if this is exactly what you want:

// sort the list by the first element in each tuple
val sort_a = a.sortBy(_._1)

// calculate the difference of consecutive tuples by the first element
val diff = sort_a.scanLeft((0.0, 0.0))((x,y) => (y._1 - x._2, y._1)).tail

// create a group variable based on the difference and tolerance;
// 0.201 rather than 0.2 leaves slack for floating-point error (3.1 - 2.9 is slightly above 0.2)
val g = diff.scanLeft(0)((x, y) => if (y._1 < 0.201) x else x + 1).tail
// g: List[Int] = List(1, 1, 2, 2)

// zip the list and the group variable and split the list up by the group variable
sort_a.zip(g).groupBy(_._2).mapValues(_.map(_._1))
// res62: scala.collection.immutable.Map[Int,List[(Double, Int)]] = 
// Map(2 -> List((2.9,5), (3.1,4)), 1 -> List((1.1,2), (1.2,3)))
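The same idea (sort, then start a new group whenever the gap to the previous value exceeds the tolerance) can also be written as a single `foldLeft` pass, which avoids the two `scanLeft` passes and the `zip`. This is only a sketch: the names `GapGrouping` and `groupByGap` and the `eps` slack parameter are hypothetical, with `eps` standing in for the hard-coded `0.201` above.

```scala
object GapGrouping {
  // Hypothetical helper: sort by the Double, then start a new group whenever the
  // gap to the previous value exceeds `delta` (plus a small `eps` for rounding error).
  def groupByGap(xs: List[(Double, Int)], delta: Double, eps: Double = 1e-9): List[List[(Double, Int)]] = {
    val sorted = xs.sortBy(_._1)
    sorted.foldLeft(List.empty[List[(Double, Int)]]) {
      // the first element opens the first group
      case (Nil, x) => List(List(x))
      // gap to the previous value is within tolerance: prepend to the current group
      case (cur :: done, x) if x._1 - cur.head._1 <= delta + eps => (x :: cur) :: done
      // gap too large: open a new group
      case (groups, x) => List(x) :: groups
    }.map(_.reverse).reverse // groups were built by prepending, so restore ascending order
  }
}
```

To get the Map shape from the question (keyed by the Doubles in each group), something like `GapGrouping.groupByGap(a, 0.2).map(g => g.map(_._1) -> g).toMap` should work.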

Answer 1 (score: 1)

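Here is a (tail-)recursive implementation. The main difference from using scan and the Collections API is that the compiler turns the tail call into a while loop, which usually runs very fast.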

import scala.annotation.tailrec

def grouper(seq: List[(Double, Int)], delta: Double): Map[List[Double], List[(Double, Int)]] = {
  @tailrec
  def loop(rest: List[(Double, Int)], last: Double, curGroup: List[(Double, Int)],
           allGroups: List[List[(Double, Int)]]): List[List[(Double, Int)]] =
    rest match {
      // next value is within delta of the previous one: extend the current group
      case h :: t if math.abs(h._1 - last) <= delta => loop(t, h._1, h :: curGroup, allGroups)
      // gap larger than delta: close the current group (if any) and start a new one
      case h :: t => loop(t, h._1, h :: Nil, if (curGroup.nonEmpty) curGroup :: allGroups else allGroups)
      // input exhausted: close the last group
      case _ => if (curGroup.nonEmpty) curGroup :: allGroups else allGroups
    }

  // only consecutive values are compared, so sort by the Double first
  val list = loop(seq.sortBy(_._1), Double.NegativeInfinity, List.empty, List.empty)
  list.map(x => (x.map(_._1), x)).toMap
}

Using it:

> grouper(List((1.1, 2), (1.2, 3), (1.3, 4), (2.9, 5)), 0.2)
res1: Map[List[Double], List[(Double, Int)]] = Map(List(2.9) -> List((2.9, 5)), List(1.3, 1.2, 1.1) -> List((1.3, 4), (1.2, 3), (1.1, 2)))
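For comparison with the claim that the tail-recursive version becomes a while loop, here is a hand-written loop with the same "compare with the previous value" logic. It is a sketch, not the literal code scalac generates; the names `LoopGrouper` and `groupWhile` and the `eps` parameter are assumptions. It sorts the input and adds a small `eps` because, with the question's own data, `3.1 - 2.9` evaluates to slightly more than `0.2`, so a strict `<= delta` test would put 2.9 and 3.1 into separate groups.

```scala
import scala.collection.mutable.ListBuffer

object LoopGrouper {
  // Hypothetical explicit-loop variant of `grouper`: walks the sorted list once,
  // closing the current group whenever the gap exceeds delta + eps.
  def groupWhile(seq: List[(Double, Int)], delta: Double, eps: Double = 1e-9): List[List[(Double, Int)]] = {
    val groups = ListBuffer.empty[List[(Double, Int)]]
    val current = ListBuffer.empty[(Double, Int)]
    var rest = seq.sortBy(_._1) // consecutive comparison needs sorted input
    var last = Double.NegativeInfinity
    while (rest.nonEmpty) {
      val h = rest.head
      // gap too large: emit the current group (if any) and start a fresh one
      if (h._1 - last > delta + eps && current.nonEmpty) {
        groups += current.toList
        current.clear()
      }
      current += h
      last = h._1
      rest = rest.tail
    }
    if (current.nonEmpty) groups += current.toList // emit the final group
    groups.toList
  }
}
```

It returns the groups in ascending order; `.map(g => g.map(_._1) -> g).toMap` converts them to the same Map shape that `grouper` returns.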