I am trying to modify the code below to take a third Point object parameter, but this line:
val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) =>
causes this compile-time error:
Multiple markers at this line
- type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) =>
scala.algorithms.Point required: (?, scala.algorithms.Point) => ?
- type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) =>
scala.algorithms.Point required: (?, scala.algorithms.Point) => ?
The complete code:
package scala.algorithms
/**
* Modified from http://garysieling.com/blog/implementing-k-means-in-scala
*
*/
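class Point(val x: Double, val y: Double, val z : Double) {
  override def toString(): String = {
    // Print all three coordinates, including z
    "(" + x + ", " + y + ", " + z + ")"
  }
  // Squared Euclidean distance from this point to p
  // (enough for comparing which center is closest)
  def dist(p: Point): Double = {
    val dx = x - p.x
    val dy = y - p.y
    val dz = z - p.z
    dx * dx + dy * dy + dz * dz
  }
}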
object kmeans extends App {
  val NUMBER_OF_CLUSTERS = 5;
  val k: Int = 2
  val points: List[Point] = List(
    new Point(0, 0, 1),
    new Point(1, 0, 1),
    new Point(0, 1, 0)).sortBy(
    p => (p.x + " " + p.y).hashCode())
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) =>
      new Point(a.x + b.x + c.x, a.y + b.y + c.y , a.z + b.z + c.z))
    new Point(cumulative.x / points.length, cumulative.y / points.length
      , cumulative.z / points.length)
  }
  def render(points: Map[Int, List[Point]]) {
    for (clusterNumber <- points.keys.toSeq.sorted) {
      println(" Cluster " + clusterNumber)
      val meanPoint = clusterMean(points(clusterNumber))
      println(" Mean: " + meanPoint)
      for (j <- 0 to points(clusterNumber).length - 1) {
        System.out.println(" " + points(clusterNumber)(j) + ")")
      }
    }
  }
  val clusters =
    points.zipWithIndex.groupBy(
      x => x._2 % k) transform (
      (i: Int, p: List[(Point, Int)]) => for (x <- p) yield x._1)
  println("Initial State: ")
  render(clusters)
  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    val unzippedClusters =
      (clusters: Iterator[(Point, Int)]) => clusters.map(cluster => cluster._1)
    // find cluster means
    val means =
      (clusters: Map[Int, List[Point]]) =>
        for (clusterIndex <- clusters.keys)
          yield clusterMean(clusters(clusterIndex))
    // find the closest index
    def closest(p: Point, means: Iterable[Point]): Int = {
      val distances = for (center <- means) yield p.dist(center)
      distances.zipWithIndex.min._2
    }
    // assignment step
    val newClusters =
      points.groupBy(
        (p: Point) => closest(p, means(clusters)))
    render(newClusters)
    newClusters
  }
  var clusterToTest = clusters
  for (i <- 0 to NUMBER_OF_CLUSTERS) {
    System.out.println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}
Reading the documentation for the reduceLeft method at http://www.scala-lang.org/api/current/index.html#index.index-r:
Applies a binary operator to all elements of this sequence, going left to right.
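So I think I need to use a different method here?
The reduceLeft method is also listed under several traits:
IndexedSeqOptimized LinearSeqOptimized TraversableOnce TraversableProxyLike TraversableForwarder Stream ParIterableLike
How do I know which trait's reduceLeft implementation is actually being used?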
Answer 0 (score: 0)
The reduceLeft method takes a function of two arguments as its parameter, so you should use it like this:
points.reduce( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
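To make the point above concrete, here is a minimal sketch of how clusterMean from the question might look with a two-argument function. The object name ClusterMeanSketch and the main method are only for illustration and are not part of the original code; the Point class is a trimmed copy of the one in the question. The third coordinate is handled inside the binary function, not by adding a third parameter.

```scala
// Trimmed copy of the question's Point class (three coordinates).
class Point(val x: Double, val y: Double, val z: Double) {
  override def toString: String = s"($x, $y, $z)"
}

// Hypothetical wrapper object, used here only to make the sketch runnable.
object ClusterMeanSketch {
  // reduceLeft combines two values at a time, left to right:
  // ((p1 + p2) + p3) + ... so the function takes exactly two Points.
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point) =>
      new Point(a.x + b.x, a.y + b.y, a.z + b.z))
    val n = points.length
    new Point(cumulative.x / n, cumulative.y / n, cumulative.z / n)
  }

  def main(args: Array[String]): Unit = {
    val points = List(new Point(0, 0, 1), new Point(1, 0, 1), new Point(0, 1, 0))
    println(clusterMean(points))
  }
}
```

Here `a` is the running sum so far and `b` is the next element, which is why no third parameter is needed no matter how many points the list holds.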
Note that you will get an exception if points is empty. You can use reduceOption or fold to avoid the exception:
points.fold(new Point(0, 0, 0))( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
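The difference between the two alternatives can be sketched as follows. The names Point3, EmptySafeSum, add, sumOption, sumOrZero and meanOption are assumptions made up for this example (a case class is used so results can be compared with ==); only reduceOption and fold come from the standard library.

```scala
// Hypothetical stand-in for the question's Point; a case class gives us ==.
final case class Point3(x: Double, y: Double, z: Double)

// Hypothetical helper object, not part of the original code.
object EmptySafeSum {
  // Binary operator: adds two points coordinate by coordinate.
  val add: (Point3, Point3) => Point3 =
    (a, b) => Point3(a.x + b.x, a.y + b.y, a.z + b.z)

  // reduceOption returns None for an empty list instead of throwing.
  def sumOption(ps: List[Point3]): Option[Point3] = ps.reduceOption(add)

  // fold starts from the given zero element, so an empty list yields the zero.
  def sumOrZero(ps: List[Point3]): Point3 = ps.fold(Point3(0, 0, 0))(add)

  // A mean is undefined for an empty list, so Option is a natural fit here.
  def meanOption(ps: List[Point3]): Option[Point3] =
    sumOption(ps).map(s => Point3(s.x / ps.length, s.y / ps.length, s.z / ps.length))

  def main(args: Array[String]): Unit = {
    println(sumOption(Nil))
    println(sumOrZero(Nil))
    println(meanOption(List(Point3(0, 0, 2), Point3(2, 0, 0))))
  }
}
```

For computing a mean, reduceOption is probably the safer choice, because fold with a zero point would still lead to a division by zero length on an empty cluster.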
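You can use the documentation to find out where the method is implemented: look at the "Definition Classes" entry (TraversableOnce) in the description of reduceLeft.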