Should I use the reduceLeft method in this case?

Asked: 2013-07-12 11:22:03

Tags: scala

I am trying to modify the code below so that the function passed to reduceLeft takes a third Point parameter, but this line:

val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) => 

causes this compile-time error:

Multiple markers at this line
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
     scala.algorithms.Point required: (?, scala.algorithms.Point) => ?
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
     scala.algorithms.Point required: (?, scala.algorithms.Point) => ?

The full code:

package scala.algorithms

/**
 * Modified from http://garysieling.com/blog/implementing-k-means-in-scala
 * 
 */

class Point(val x: Double, val y: Double, val z : Double) {

  override def toString(): String = {
    "(" + x + ", " + y + ")"
  }

  def dist(p: Point): Double = {
    x * x + y * y + z * z
  }
}

object kmeans extends App {

   val NUMBER_OF_CLUSTERS = 5;

  val k: Int = 2

  val points: List[Point] = List(
    new Point(0, 0, 1),
    new Point(1, 0, 1),
    new Point(0, 1, 0)).sortBy(
      p => (p.x + " " + p.y).hashCode())

  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) => 
      new Point(a.x + b.x + c.x, a.y + b.y + c.y , a.z + b.z + c.z))

    new Point(cumulative.x / points.length, cumulative.y / points.length
        , cumulative.z / points.length)
  }

  def render(points: Map[Int, List[Point]]) {
    for (clusterNumber <- points.keys.toSeq.sorted) {
      println("  Cluster " + clusterNumber)

      val meanPoint = clusterMean(points(clusterNumber))
      println("  Mean: " + meanPoint)

      for (j <- 0 to points(clusterNumber).length - 1) {
        System.out.println("    " + points(clusterNumber)(j) + ")")
      }
    }
  }

  val clusters =
    points.zipWithIndex.groupBy(
      x => x._2 % k) transform (
        (i: Int, p: List[(Point, Int)]) => for (x <- p) yield x._1)

  println("Initial State: ")
  render(clusters)

  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    val unzippedClusters =
      (clusters: Iterator[(Point, Int)]) => clusters.map(cluster => cluster._1)

    // find cluster means
    val means =
      (clusters: Map[Int, List[Point]]) =>
        for (clusterIndex <- clusters.keys)
          yield clusterMean(clusters(clusterIndex))

    // find the closest index
    def closest(p: Point, means: Iterable[Point]): Int = {
      val distances = for (center <- means) yield p.dist(center)
      distances.zipWithIndex.min._2
    }

    // assignment step
    val newClusters =
      points.groupBy(
        (p: Point) => closest(p, means(clusters)))

    render(newClusters)

    newClusters
  }

  var clusterToTest = clusters
  for (i <- 0 to NUMBER_OF_CLUSTERS) {
    System.out.println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}

The documentation for the reduceLeft method at http://www.scala-lang.org/api/current/index.html#index.index-r says:

Applies a binary operator to all elements of this sequence, going left to right.

So I suppose I need to use a different method here?

The reduceLeft method is also listed under multiple traits:

IndexedSeqOptimized LinearSeqOptimized TraversableOnce TraversableProxyLike TraversableForwarder Stream ParIterableLike

How do I know which trait's reduceLeft implementation is actually being used?

1 Answer:

Answer 0 (score: 0)

The reduceLeft method takes a function of two parameters as its argument, so you should use it like this:

points.reduce( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
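
To show how that fits the clusterMean function from the question, here is a minimal sketch. It assumes a stripped-down copy of the question's Point class (toString and dist omitted), and the names ClusterMeanSketch, acc and p are only illustrative. The binary operator receives the running sum and the next element, so no third parameter is needed:

```scala
// Stripped-down stand-in for the question's Point class (assumption: only the fields matter here).
class Point(val x: Double, val y: Double, val z: Double)

object ClusterMeanSketch {
  // reduceLeft takes a binary operator: `acc` is the running sum so far,
  // `p` is the next element of the list.
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((acc: Point, p: Point) =>
      new Point(acc.x + p.x, acc.y + p.y, acc.z + p.z))

    // Divide the component-wise sum by the number of points to get the mean.
    val n = points.length
    new Point(cumulative.x / n, cumulative.y / n, cumulative.z / n)
  }
}
```

As in the original, this version still throws if the list is empty.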

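Note that this throws an exception (an UnsupportedOperationException) when points is empty. You can use reduceOption or fold to avoid the exception: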

points.fold(new Point(0, 0, 0))( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
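
The reduceOption variant could look roughly like the sketch below. It again assumes a stripped-down Point class, and the names SafeMeanSketch and clusterMeanOption are hypothetical. One thing to keep in mind with the fold version: if its result is then divided by points.length for an empty list, the coordinates come out as NaN (0.0 / 0), so returning an Option may be the clearer choice for a mean:

```scala
// Stripped-down stand-in for the question's Point class (assumption: only the fields matter here).
class Point(val x: Double, val y: Double, val z: Double)

object SafeMeanSketch {
  // Returns None for an empty cluster instead of throwing an exception.
  def clusterMeanOption(points: List[Point]): Option[Point] =
    points
      .reduceOption((a: Point, b: Point) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
      .map { sum =>
        // Only reached when there is at least one point, so n > 0.
        val n = points.length
        new Point(sum.x / n, sum.y / n, sum.z / n)
      }
}
```

A caller can then pattern match on the result, or use getOrElse to supply a default point for empty clusters.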

You can use the documentation to find out where the method is implemented. The reduceLeft description there includes this entry:

Definition Classes TraversableOnce