以任何方式防止Scala TreeSet替换等效值

时间:2014-09-08 04:19:48

标签: scala scala-collections

使用Scala 2.10.2 - 我想构建一组有序的浮点值,并根据指定的容差进行比较,如下例所示:

implicit object DiffAtLeastOne extends Ordering[Double] {
  def compare(a: Double, b: Double): Int =
    if ((a - b).abs < 1.0) 0
    else a.compare(b)
}


val ts = scala.collection.mutable.TreeSet.empty[Double]
ts += 0.0
ts += 0.9
ts += 1.8
println( ts )  // prints TreeSet(1.8)

我曾预料到,因为连续的值被视为等价,所以该集将保留第一个(0.0)和最后一个(1.8),而是用下一个值替换每个先前的值。有没有一种简单的方法可以防止这种情况而不对TreeSet进行子类化(它给出了弃用警告)?

如果这是重复的道歉 - 我拖了一段时间没有成功。

更新

在查看this related answer和@ user2864740的评论后,我意识到我需要使我的值类型的equals方法符合Ordering.compare方法。以下是丑陋的,但给出了我追求的行为:

case class RoundedDouble(value: Double) {
  import RoundedDouble._

  override def equals(other: Any): Boolean =
    if (other.isInstanceOf[RoundedDouble]) {
      val otherRD = other.asInstanceOf[RoundedDouble]
      DiffAtLeastOne.compare(this, otherRD) == 0
    }
    else false
}

object RoundedDouble {
  implicit object DiffAtLeastOne extends Ordering[RoundedDouble] {
    def compare(a: RoundedDouble, b: RoundedDouble): Int =
      if ((a.value - b.value).abs < 1.0) 0
      else a.value.compare(b.value)
  }  

  implicit def fromDouble(d: Double) = RoundedDouble(d)
}


val ts = scala.collection.mutable.TreeSet.empty[RoundedDouble]
ts += 0.0
ts += 0.9
ts += 1.8

println( ts )  // prints TreeSet(RoundedDouble(0.0), RoundedDouble(1.8))

毫无疑问,有一种更优雅的选择。

1 个答案:

答案 0 :(得分:0)

回答我自己的问题(但仍希望得到更好的答案)......

以下构图解决方案(受this answer启发)似乎可以满足我的需求:

import scala.collection.mutable.{SortedSet, TreeSet}

object RoundedDoubleSortedSet {
  val Tolerance = 1.0e-8

  def almostEq(a: Double, b: Double) =
    (a - b).abs < Tolerance

  val ordering = new Ordering[Double] {
    def compare(a: Double, b: Double): Int =
      if (almostEq(a, b)) 0
      else a.compare(b)
  }
}

final class RoundedDoubleSortedSet extends SortedSet[Double] {

  override val ordering: Ordering[Double] = RoundedDoubleSortedSet.ordering 
  private val ts = TreeSet.empty[Double](ordering)

  override def contains(x: Double) = ts.contains(x)
  override def += (x: Double) = { if (!contains(x)) ts += x; this }
  override def -= (x: Double) = { ts -= x; this }  
  override def iterator: Iterator[Double] = ts.iterator
  override def keysIteratorFrom(start: Double): Iterator[Double] = ts.keysIteratorFrom(start)
  override def rangeImpl(from: Option[Double], until: Option[Double]) = ts.rangeImpl(from, until)
}