scala中间值的快速实现是什么?
这是我在rosetta code上找到的:
def median(s: Seq[Double]) =
{
val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
}
我不喜欢它,因为它做了一种。我知道有很多方法可以计算线性时间的中位数。
修改
我想拥有一组可以在各种场景中使用的中值函数:
O(log n)
值保留在内存中like this O(log n)
个值,并且您最多可以遍历一次流(这是否可能?)请仅发布编译的代码,正确计算中位数。为简单起见,您可以假设所有输入都包含奇数个值。
答案 0 :(得分:56)
first algorithm的indicated Taylor Leese是二次的,但具有线性平均值。然而,这取决于枢轴选择。所以我在这里提供了一个具有可插入枢轴选择的版本,并且随机枢轴和中位数的中位数都是枢轴(这保证了线性时间)。
import scala.annotation.tailrec
@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partition (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty) {
val (s, b) = arr partition (a ==)
if (s.size > k) a
else findKMedian(b, k - s.size)
} else if (s.size < k) findKMedian(b, k - s.size)
else findKMedian(s, k)
}
def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
这是随机数据透视选择。使用随机因子的算法分析比正常情况更棘手,因为它主要处理概率和统计。
def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
中位数方法的中位数,与上述算法一起使用时可保证线性时间。首先,算法计算最多5个数字的中位数,这是中位数算法的中位数的基础。这个是由Rex Kerr中的this answer提供的 - 算法在很大程度上取决于它的速度。
def medianUpTo5(five: Array[Double]): Double = {
def order2(a: Array[Double], i: Int, j: Int) = {
if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
}
def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
if (a(i)<a(k)) { order2(a,j,k); a(j) }
else { order2(a,i,l); a(i) }
}
if (five.length < 2) return five(0)
order2(five,0,1)
if (five.length < 4) return (
if (five.length==2 || five(2) < five(0)) five(0)
else if (five(2) > five(1)) five(1)
else five(2)
)
order2(five,2,3)
if (five.length < 5) pairs(five,0,1,2,3)
else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
else { order2(five,3,4); pairs(five,0,1,3,4) }
}
然后,中位数算法本身的中位数。基本上,它保证选择的枢轴将大于至少30%并且小于列表的其他30%,这足以保证先前算法的线性。查看另一个答案中提供的维基百科链接以获取详细信息。
def medianOfMedians(arr: Array[Double]): Double = {
val medians = arr grouped 5 map medianUpTo5 toArray;
if (medians.size <= 5) medianUpTo5 (medians)
else medianOfMedians(medians)
}
所以,这是算法的就地版本。我正在使用一个使用支持数组实现分区的类,以便对算法的更改最小化。
case class ArrayView(arr: Array[Double], from: Int, until: Int) {
def apply(n: Int) =
if (from + n < until) arr(from + n)
else throw new ArrayIndexOutOfBoundsException(n)
def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
var upper = until - 1
var lower = from
while (lower < upper) {
while (lower < until && p(arr(lower))) lower += 1
while (upper >= from && !p(arr(upper))) upper -= 1
if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
}
(copy(until = lower), copy(from = lower))
}
def size = until - from
def isEmpty = size <= 0
override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}
@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partitionInPlace (a >)
if (s.size == k) a
// The following test is used to avoid infinite repetition
else if (s.isEmpty) {
val (s, b) = arr partitionInPlace (a ==)
if (s.size > k) a
else findKMedianInPlace(b, k - s.size)
} else if (s.size < k) findKMedianInPlace(b, k - s.size)
else findKMedianInPlace(s, k)
}
def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
我只是为就地算法实现了radom pivot,因为中位数的中位数需要比我定义的ArrayView
类目前提供的支持更多的支持。
def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
所以,关于流。对于只能遍历一次的流,不可能为O(n)
内存执行任何操作,除非您碰巧知道字符串长度是什么(在这种情况下它不再是我书中的流)。 / p>
使用存储桶也有点问题,但如果我们可以多次遍历它,那么我们就可以知道它的大小,最大值和最小值,并从那里开始工作。例如:
def findMedianHistogram(s: Traversable[Double]) = {
def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
// The buckets
def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
val buckets = new Array[Int](numberOfBuckets)
// The upper limit of each bucket
val max = s.max
val min = s.min
val increment = (max - min) / numberOfBuckets
val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)
// Return the bucket a number is supposed to be in
def bucketIndex(d: Double) = indices indexWhere (d <=)
// Compute how many in each bucket
s foreach { d => buckets(bucketIndex(d)) += 1 }
// Now make the buckets cumulative
val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)
// The bucket where our target is at
val medianBucket = partialTotals indexWhere (medianIndex <)
// Keep track of how many numbers there are that are less
// than the median bucket
val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)
// Test whether a number is in the median bucket
def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket
// Get a view of the target bucket
val view = s.view filter insideMedianBucket
// If all numbers in the bucket are equal, return that
if (view forall (view.head ==)) view.head
// Otherwise, recurse on that bucket
else medianHistogram(view, newDiscarded, medianIndex)
}
medianHistogram(s, 0, (s.size - 1) / 2)
}
为了测试算法,我使用Scalacheck,并将每个算法的输出与通过排序的简单实现的输出进行比较。当然,这假设排序版本是正确的。
我正在使用所有提供的枢轴选择对上述每个算法进行基准测试,加上固定的枢轴选择(阵列的一半,向下舍入)。每个算法都使用三种不同的输入数组大小进行测试,并针对每种算法进行三次测试。
这是测试代码:
import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._
def test(algorithm: Array[Double] => Double,
reference: Array[Double] => Double): String = {
def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
val resultEqualsReference = forAll { (arr: Array[Double]) =>
arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
}
Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}
import java.lang.System.currentTimeMillis
def bench[A](n: Int)(body: => A): Long = {
val start = currentTimeMillis()
1 to n foreach { _ => body }
currentTimeMillis() - start
}
import scala.util.Random.nextDouble
def benchmark(algorithm: Array[Double] => Double,
arraySizes: List[Int]): List[Iterable[Long]] =
for (size <- arraySizes)
yield for (iteration <- 1 to 3)
yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))
def testAndBenchmark: String = {
val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
"Random Pivot" -> chooseRandomPivot,
"Median of Medians" -> medianOfMedians,
"Midpoint" -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
)
val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
"Random Pivot (in-place)" -> chooseRandomPivotInPlace,
"Midpoint (in-place)" -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
)
val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
yield name -> (findMedian(_: Array[Double])(pivotSelection))
val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms
val formattingString = "%%-%ds %%s" format (algorithms map (_._1.length) max)
// Tests
val testResults = for ((name, algorithm) <- algorithms)
yield formattingString format (name, test(algorithm, sortingAlgorithm._2))
// Benchmarks
val arraySizes = List(100, 500, 1000)
def formatResults(results: List[Long]) = results map ("%8d" format _) mkString
val benchmarkResults: List[String] = for {
(name, algorithm) <- algorithms
results <- benchmark(algorithm, arraySizes).transpose
} yield formattingString format (name, formatResults(results))
val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))
"Tests" :: "*****" :: testResults :::
("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
试验:
Tests
*****
Sorting OK, passed 100 tests.
Histogram OK, passed 100 tests.
Random Pivot OK, passed 100 tests.
Median of Medians OK, passed 100 tests.
Midpoint OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place) OK, passed 100 tests.
基准:
Benchmark
*********
Algorithm 100 500 1000
Sorting 1038 6230 14034
Sorting 1037 6223 13777
Sorting 1039 6220 13785
Histogram 2918 11065 21590
Histogram 2596 11046 21486
Histogram 2592 11044 21606
Random Pivot 904 4330 8622
Random Pivot 902 4323 8815
Random Pivot 896 4348 8767
Median of Medians 3591 16857 33307
Median of Medians 3530 16872 33321
Median of Medians 3517 16793 33358
Midpoint 1003 4672 9236
Midpoint 1010 4755 9157
Midpoint 1017 4663 9166
Random Pivot (in-place) 392 1746 3430
Random Pivot (in-place) 386 1747 3424
Random Pivot (in-place) 386 1751 3431
Midpoint (in-place) 378 1735 3405
Midpoint (in-place) 377 1740 3408
Midpoint (in-place) 375 1736 3408
所有算法(排序版本除外)都具有与平均线性时间复杂度兼容的结果。
在最坏的情况下保证线性时间复杂度的中位数中位数比随机数据点慢得多。
固定枢轴选择稍微差于随机旋转,但在非随机输入上可能会有更差的性能。
就地版本的速度提高了大约230%~250%,但是进一步的测试(未显示)似乎表明这种优势会随着阵列的大小而增加。
我对直方图算法感到非常惊讶。它显示了平均线性时间复杂度,并且比中位数的中位数快33%。但是,输入是随机的。最坏的情况是二次的 - 我在调试代码时看到了一些例子。