Scala 中值实现

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/4662292/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-10-22 02:43:03  来源:igfitidea点击:

scala median implementation

algorithmscalamedian

提问by dsg

What's a fast implementation of median in scala?

什么是scala中中位数的快速实现?

This is what I found on rosetta code:

这是我在rosetta 代码上发现的:

  def median(s: Seq[Double])  =
  {
    val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }

I don't like it because it does a sort. I know there are ways to compute the median in linear time.

我不喜欢它,因为它有点像。我知道有一些方法可以计算线性时间的中位数。

EDIT:

编辑:

I would like to have a set of median functions that I can use in various scenarios:

我想要一组可以在各种场景中使用的中值函数:

  1. fast, in place median computation that can be done in linear time
  2. median that works on a stream that you can traverse multiple times, but you can only keep O(log n)values in memory like this
  3. median that works on a stream, where you can hold at most O(log n)values in memory, and you can traverse the stream at most once (is this even possible?)
  1. 可以在线性时间内完成的快速、就地中值计算
  2. 中值适用于您可以多次遍历的流,但您只能像这样O(log n)值保留在内存中
  3. 适用于流的中O(log n)值,您最多可以在内存中保存值,并且您最多可以遍历流一次(这甚至可能吗?)

Please only post code that compilesand correctly computes the median. For simplicity, you may assume that all inputs contain an odd number of values.

请只发布编译正确计算中位数的代码。为简单起见,您可以假设所有输入都包含奇数个值。

回答by Daniel C. Sobral

Immutable Algorithm

不可变算法

The first algorithmindicatedby Taylor Leeseis quadratic, but has linear average. That, however, depends on the pivot selection. So I'm providing here a version which has a pluggable pivot selection, and both the random pivot and the median of medians pivot (which guarantees linear time).

所述第一算法表示通过泰勒利斯是二次的,而是具有线性平均。但是,这取决于枢轴选择。所以我在这里提供了一个版本,它有一个可插入的枢轴选择,以及随机枢轴和中值枢轴的中位数(保证线性时间)。

import scala.annotation.tailrec

@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partition (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partition (a ==)
        if (s.size > k) a
        else findKMedian(b, k - s.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
}

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)

Random Pivot (quadratic, linear average), Immutable

随机枢轴(二次,线性平均),不可变

This is the random pivot selection. Analysis of algorithms with random factors is trickier than normal, because it deals largely with probability and statistics.

这是随机枢轴选择。分析具有随机因素的算法比正常情况更棘手,因为它主要涉及概率和统计。

def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))

Median of Medians (linear), Immutable

中位数(线性),不可变

The median of medians method, which guarantees linear time when used with the algorithm above. First, and algorithm to compute the median of up to 5 numbers, which is the basis of the median of medians algorithm. This one was provided by Rex Kerrin this answer-- the algorithm depends a lot on the speed of it.

中位数方法的中位数,当与上述算法一起使用时保证线性时间。首先,算法计算最多5个数字的中位数,这是中位数算法的基础。这个是由Rex Kerr这个答案中提供的——算法在很大程度上取决于它的速度。

def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}

And, then, the median of medians algorithm itself. Basically, it guarantees that the choosen pivot will be greater than at least 30% and smaller than other 30% of the list, which is enough to guarantee the linearity of the previous algorithm. Look up the wikipedia link provided in another answer for details.

然后,中位数算法本身的中位数。基本上,它保证选择的pivot至少大于30%,小于列表的其他30%,足以保证前面算法的线性。有关详细信息,请查看另一个答案中提供的维基百科链接。

def medianOfMedians(arr: Array[Double]): Double = {
    val medians = arr grouped 5 map medianUpTo5 toArray;
    if (medians.size <= 5) medianUpTo5 (medians)
    else medianOfMedians(medians)
}

In-place Algorithm

就地算法

So, here's an in-place version of the algorithm. I'm using a class that implements a partition in-place, with a backing array, so that the changes to the algorithms are minimal.

因此,这是该算法的就地版本。我正在使用一个类来实现就地分区,并带有一个支持数组,因此对算法的更改最小。

case class ArrayView(arr: Array[Double], from: Int, until: Int) {
    def apply(n: Int) = 
        if (from + n < until) arr(from + n)
        else throw new ArrayIndexOutOfBoundsException(n)

    def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
      var upper = until - 1
      var lower = from
      while (lower < upper) {
        while (lower < until && p(arr(lower))) lower += 1
        while (upper >= from && !p(arr(upper))) upper -= 1
        if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
      }
      (copy(until = lower), copy(from = lower))
    }

    def size = until - from
    def isEmpty = size <= 0

    override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
    def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partitionInPlace (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partitionInPlace (a ==)
        if (s.size > k) a
        else findKMedianInPlace(b, k - s.size)
    } else if (s.size < k) findKMedianInPlace(b, k - s.size)
    else findKMedianInPlace(s, k)
}

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)

Random Pivot, In-place

随机枢轴,就地

I'm only implementing the radom pivot for the in-place algorithms, as the median of medians would require more support than what is presently provided by the ArrayViewclass I defined.

我只是为就地算法实现随机枢轴,因为中位数需要比ArrayView我定义的类目前提供的支持更多的支持。

def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))

Histogram Algorithm (O(log(n)) memory), Immutable

直方图算法(O(log(n)) 内存),不可变

So, about streams. It is impossible to do anything less than O(n)memory for a stream that can only be traversed once, unless you happen to know what the string length is (in which case it ceases to be a stream in my book).

所以,关于流。O(n)对于只能遍历一次的流,除了内存之外不可能做任何事情,除非您碰巧知道字符串长度是多少(在这种情况下,它不再是我书中的流)。

Using buckets is also a bit problematic, but if we can traverse it multiple times, then we can know its size, maximum and minimum, and work from there. For example:

使用桶也有点问题,但是如果我们可以多次遍历它,那么我们就可以知道它的大小、最大值和最小值,并从那里开始工作。例如:

def findMedianHistogram(s: Traversable[Double]) = {
    def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
        // The buckets
        def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
        val buckets = new Array[Int](numberOfBuckets)

        // The upper limit of each bucket
        val max = s.max
        val min = s.min
        val increment = (max - min) / numberOfBuckets
        val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

        // Return the bucket a number is supposed to be in
        def bucketIndex(d: Double) = indices indexWhere (d <=)

        // Compute how many in each bucket
        s foreach { d => buckets(bucketIndex(d)) += 1 }

        // Now make the buckets cumulative
        val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

        // The bucket where our target is at
        val medianBucket = partialTotals indexWhere (medianIndex <)

        // Keep track of how many numbers there are that are less 
        // than the median bucket
        val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

        // Test whether a number is in the median bucket
        def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

        // Get a view of the target bucket
        val view = s.view filter insideMedianBucket

        // If all numbers in the bucket are equal, return that
        if (view forall (view.head ==)) view.head
        // Otherwise, recurse on that bucket
        else medianHistogram(view, newDiscarded, medianIndex)
    }

    medianHistogram(s, 0, (s.size - 1) / 2)
}

Test and Benchmark

测试和基准

To test the algorithms, I'm using Scalacheck, and comparing the output of each algorithm to the output of a trivial implementation with sorting. That assumes the sorting version is correct, of course.

为了测试算法,我使用Scalacheck,并将每个算法的输出与带有排序的简单实现的输出进行比较。当然,这假设排序版本是正确的。

I'm benchmarking each of the above algorithms with all provided pivot selections, plus a fixed pivot selection (halfway the array, round down). Each algorithm is tested with three different input array sizes, and for three times against each one.

我正在使用所有提供的枢轴选择以及固定的枢轴选择(数组的一半,向下取整)对上述每个算法进行基准测试。每种算法都使用三种不同的输入数组大小进行测试,并针对每种算法进行 3 次测试。

Here's the testing code:

下面是测试代码:

import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, 
         reference: Array[Double] => Double): String = {
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    val resultEqualsReference = forAll { (arr: Array[Double]) => 
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] = 
    for (size <- arraySizes)
    yield for (iteration <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))

def testAndBenchmark: String = {
    val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
        "Random Pivot"      -> chooseRandomPivot,
        "Median of Medians" -> medianOfMedians,
        "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
    )
    val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
        "Random Pivot (in-place)" -> chooseRandomPivotInPlace,
        "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
    )
    val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
    val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
        yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
    val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
    val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
    val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

    val formattingString = "%%-%ds  %%s" format (algorithms map (_._1.length) max)

    // Tests
    val testResults = for ((name, algorithm) <- algorithms)
        yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

    // Benchmarks
    val arraySizes = List(100, 500, 1000)
    def formatResults(results: List[Long]) = results map ("%8d" format _) mkString

    val benchmarkResults: List[String] = for {
        (name, algorithm) <- algorithms
        results <- benchmark(algorithm, arraySizes).transpose
    } yield formattingString format (name, formatResults(results))

    val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

    "Tests" :: "*****" :: testResults ::: 
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}

Results

结果

Tests:

测试:

Tests
*****
Sorting                OK, passed 100 tests.
Histogram              OK, passed 100 tests.
Random Pivot           OK, passed 100 tests.
Median of Medians      OK, passed 100 tests.
Midpoint               OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place)    OK, passed 100 tests.

Benchmarks:

基准:

Benchmark
*********
Algorithm                   100     500    1000
Sorting                    1038    6230   14034
Sorting                    1037    6223   13777
Sorting                    1039    6220   13785
Histogram                  2918   11065   21590
Histogram                  2596   11046   21486
Histogram                  2592   11044   21606
Random Pivot                904    4330    8622
Random Pivot                902    4323    8815
Random Pivot                896    4348    8767
Median of Medians          3591   16857   33307
Median of Medians          3530   16872   33321
Median of Medians          3517   16793   33358
Midpoint                   1003    4672    9236
Midpoint                   1010    4755    9157
Midpoint                   1017    4663    9166
Random Pivot (in-place)     392    1746    3430
Random Pivot (in-place)     386    1747    3424
Random Pivot (in-place)     386    1751    3431
Midpoint (in-place)         378    1735    3405
Midpoint (in-place)         377    1740    3408
Midpoint (in-place)         375    1736    3408

Analysis

分析

All algorithms (except the sorting version) have results that are compatible with average linear time complexity.

所有算法(排序版本除外)的结果都与平均线性时间复杂度兼容。

The median of medians, which guarantees linear time complexity in the worst case is much slower than the random pivot.

在最坏情况下保证线性时间复杂度的中位数比随机枢轴慢得多。

The fixed pivot selection is slightly worse than random pivot, but may have much worse performance on non-random inputs.

固定支点选择比随机支点略差,但在非随机输入上的性能可能更差。

The in-place version is about 230% ~ 250% faster, but further tests (not shown) seem to indicate this advantage grows with the size of the array.

就地版本大约快 230% ~ 250%,但进一步的测试(未显示)似乎表明这种优势随着阵列的大小而增长。

I was very surprised by the histogram algorithm. It displayed linear time complexity average, and it's also 33% faster than the median of medians. However, the input israndom. The worst case is quadratic -- I saw some examples of it while I was debugging the code.

我对直方图算法感到非常惊讶。它显示了线性时间复杂度平均值,并且比中位数的中位数快 33%。但是,输入随机的。最坏的情况是二次方——我在调试代码时看到了一些例子。