使用Scala 2.10.2 - 我想构建一组有序的浮点值,并根据指定的容差进行比较,如下例所示:
implicit object DiffAtLeastOne extends Ordering[Double] {
def compare(a: Double, b: Double): Int =
if ((a - b).abs < 1.0) 0
else a.compare(b)
}
val ts = scala.collection.mutable.TreeSet.empty[Double]
ts += 0.0
ts += 0.9
ts += 1.8
println( ts ) // prints TreeSet(1.8)
我曾预料到,因为连续的值被视为等价,所以该集将保留第一个(0.0)和最后一个(1.8),而是用下一个值替换每个先前的值。有没有一种简单的方法可以防止这种情况而不对TreeSet进行子类化(它给出了弃用警告)?
如果这是重复的道歉 - 我拖了一段时间没有成功。
更新
在查看this related answer和@ user2864740的评论后,我意识到我需要使我的值类型的equals方法符合Ordering.compare方法。以下是丑陋的,但给出了我追求的行为:
case class RoundedDouble(value: Double) {
import RoundedDouble._
override def equals(other: Any): Boolean =
if (other.isInstanceOf[RoundedDouble]) {
val otherRD = other.asInstanceOf[RoundedDouble]
DiffAtLeastOne.compare(this, otherRD) == 0
}
else false
}
object RoundedDouble {
implicit object DiffAtLeastOne extends Ordering[RoundedDouble] {
def compare(a: RoundedDouble, b: RoundedDouble): Int =
if ((a.value - b.value).abs < 1.0) 0
else a.value.compare(b.value)
}
implicit def fromDouble(d: Double) = RoundedDouble(d)
}
val ts = scala.collection.mutable.TreeSet.empty[RoundedDouble]
ts += 0.0
ts += 0.9
ts += 1.8
println( ts ) // prints TreeSet(RoundedDouble(0.0), RoundedDouble(1.8))
毫无疑问,有一种更优雅的选择。
答案 0 :(得分:0)
回答我自己的问题(但仍希望得到更好的答案)......
以下构图解决方案(受this answer启发)似乎可以满足我的需求:
import scala.collection.mutable.{SortedSet, TreeSet}
object RoundedDoubleSortedSet {
val Tolerance = 1.0e-8
def almostEq(a: Double, b: Double) =
(a - b).abs < Tolerance
val ordering = new Ordering[Double] {
def compare(a: Double, b: Double): Int =
if (almostEq(a, b)) 0
else a.compare(b)
}
}
final class RoundedDoubleSortedSet extends SortedSet[Double] {
override val ordering: Ordering[Double] = RoundedDoubleSortedSet.ordering
private val ts = TreeSet.empty[Double](ordering)
override def contains(x: Double) = ts.contains(x)
override def += (x: Double) = { if (!contains(x)) ts += x; this }
override def -= (x: Double) = { ts -= x; this }
override def iterator: Iterator[Double] = ts.iterator
override def keysIteratorFrom(start: Double): Iterator[Double] = ts.keysIteratorFrom(start)
override def rangeImpl(from: Option[Double], until: Option[Double]) = ts.rangeImpl(from, until)
}