以下是kmeans算法的实现:
package com
object Functions {
def distance(l1: (Array[Double], Array[Double])) = {
val t = l1._1.zip(l1._2)
t.map(m => Math.abs(m._1 - m._2)).sum
}
}
package com
import com.Functions._
case class Point(label : String, points : Array[Double])
object KMeans2 extends Application {
val points = List(("A1,2,10"), ("A2,2,5"), ("A3,8,4"), ("A4,5,8"), ("A5,7,5"), ("A6,6,4"), ("A7,1,2"), ("A8,4,9"))
var initialCenters = Iterable[Array[Double]](Array(2, 10), Array(5, 8), Array(1, 2))
val toDouble = points.map(m => new Point(m.split(",").head , m.split(",").tail.map(m2 => m2.toDouble)))
val k = 3
val maxNumberOfIterations = 10
for (i <- 1 to maxNumberOfIterations) {
val newCentres = getNewCenters(initialCenters)
initialCenters = newCentres._2
val map = newCentres._1
for (a <- initialCenters.toList) {
println(a.toList)
for(m <- map){
println("m is "+m._1.mkString(",")+","+m._2.flatten)
}
}
println("");
}
def getNewCenters(initialCenters: Iterable[Array[Double]]): (Map[Array[Double],List[Array[Double]]] , Iterable[Array[Double]]) = {
val joined = toDouble.map(m => initialCenters.map(p => (m.points, p))).flatten
val grouped = joined.map(m => (m, distance((m)))).grouped(3).toList
val sorted = grouped.map(m => m.sortBy(_._2).take(1)).flatten.map(m => m._1)
val mapped = sorted.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }
val averaged = mapped.values.map(m => m.transpose.map(xs => xs.sum / xs.size).toArray)
(mapped , averaged)
}
}
我试图修改getNewCenters函数,以便返回每个集群的点标签。标签范围从A1 .... A9
我最接近的是返回与每个群集相关联的点数。这是返回的mapped
Tuple
(mapped , average)
元素
如何退回标签?
因此输出应该是(不一定是精确的数据结构)
List(1.5, 3.5)
m is 1.0,2.0,List(2.0, 5.0, A2, 1.0, 2.0, A7)
m is 2.0,10.0,List(2.0, 10.0 , A1)
m is 5.0,8.0,List(8.0, 4.0, A3, 5.0, 8.0, A4, 7.0, 5.0, A5, 6.0, 4.0, A6, 4.0, 9.0, A8)
而不是:
List(1.5, 3.5)
m is 1.0,2.0,List(2.0, 5.0, 1.0, 2.0)
m is 2.0,10.0,List(2.0, 10.0)
m is 5.0,8.0,List(8.0, 4.0, 5.0, 8.0, 7.0, 5.0, 6.0, 4.0, 4.0, 9.0)
更新:
这是我到目前为止所做的:
object kmeans {
println("Welcome to the Scala worksheet") //> Welcome to the Scala worksheet
case class Point(label : String, points : List[Double])
val k = 3 //> k : Int = 3
def distance(l1: (Point, Point)) = {
val t = l1._1.points.zip(l1._2.points)
t.map(m => Math.abs(m._1 - m._2)).sum
} //> distance: (l1: (kmeans.Point, kmeans.Point))Double
val points = List(("A1,2,10"), ("A2,2,5"), ("A3,8,4"), ("A4,5,8"), ("A5,7,5"), ("A6,6,4"), ("A7,1,2"), ("A8,4,9"))
//> points : List[String] = List(A1,2,10, A2,2,5, A3,8,4, A4,5,8, A5,7,5, A6,6,
//| 4, A7,1,2, A8,4,9)
var initialCenters : List[Point] = List(Point("A1",List(2, 10)), Point("A4",List(5, 8)), Point("A7",List(1, 2)))
//> initialCenters : List[kmeans.Point] = List(Point(A1,List(2.0, 10.0)), Point
//| (A4,List(5.0, 8.0)), Point(A7,List(1.0, 2.0)))
val toDouble = points.map(m => new Point(m.split(",").head , m.split(",").tail.map(m2 => m2.toDouble).toList)).toList
//> toDouble : List[kmeans.Point] = List(Point(A1,List(2.0, 10.0)), Point(A2,Li
//| st(2.0, 5.0)), Point(A3,List(8.0, 4.0)), Point(A4,List(5.0, 8.0)), Point(A5,
//| List(7.0, 5.0)), Point(A6,List(6.0, 4.0)), Point(A7,List(1.0, 2.0)), Point(A
//| 8,List(4.0, 9.0)))
val joined = toDouble.map(m => initialCenters.map(p => (m, p))).flatten
//> joined : List[(kmeans.Point, kmeans.Point)] = List((Point(A1,List(2.0, 10.0
//| )),Point(A1,List(2.0, 10.0))), (Point(A1,List(2.0, 10.0)),Point(A4,List(5.0,
//| 8.0))), (Point(A1,List(2.0, 10.0)),Point(A7,List(1.0, 2.0))), (Point(A2,Lis
//| t(2.0, 5.0)),Point(A1,List(2.0, 10.0))), (Point(A2,Li
//| Output exceeds cutoff limit.
val grouped = joined.map(m => (m, distance(m))).grouped(k).toList
//> grouped : List[List[((kmeans.Point, kmeans.Point), Double)]] = List(List(((
//| Point(A1,List(2.0, 10.0)),Point(A1,List(2.0, 10.0))),0.0), ((Point(A1,List(2
//| .0, 10.0)),Point(A4,List(5.0, 8.0))),5.0), ((Point(A1,List(2.0, 10.0)),Point
//| (A7,List(1.0, 2.0))),9.0)), List(((Point(A2,List(2.0,
//| Output exceeds cutoff limit.
val sorted = grouped.map(m => m.sortBy(_._2).take(1)).flatten.map(m => m._1)
//> sorted : List[(kmeans.Point, kmeans.Point)] = List((Point(A1,List(2.0, 10.0
//| )),Point(A1,List(2.0, 10.0))), (Point(A2,List(2.0, 5.0)),Point(A7,List(1.0,
//| 2.0))), (Point(A3,List(8.0, 4.0)),Point(A4,List(5.0, 8.0))), (Point(A4,List(
//| 5.0, 8.0)),Point(A4,List(5.0, 8.0))), (Point(A5,List(
//| Output exceeds cutoff limit.
val mapped = sorted.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }
//> mapped : scala.collection.immutable.Map[kmeans.Point,List[kmeans.Point]] =
//| Map(Point(A4,List(5.0, 8.0)) -> List(Point(A3,List(8.0, 4.0)), Point(A4,List
//| (5.0, 8.0)), Point(A5,List(7.0, 5.0)), Point(A6,List(6.0, 4.0)), Point(A8,Li
//| st(4.0, 9.0))), Point(A7,List(1.0, 2.0)) -> List(Poin
//| Output exceeds cutoff limit.
val averaged = mapped.values.map(m => m.map(m2 => m2.points).transpose.map(xs => xs.sum / xs.size))
//> averaged : Iterable[List[Double]] = List(List(6.0, 6.0), List(1.5, 3.5), L
//| ist(2.0, 10.0))
}
答案 0 :(得分:0)
KMeans.scala:
package com.driver
import com.driver.Functions._
object KMeans extends Application {
case class Point(label: String, points: List[Double])
val k = 3
val points = List(("A1,2,10"), ("A2,2,5"), ("A3,8,4"), ("A4,5,8"), ("A5,7,5"), ("A6,6,4"), ("A7,1,2"), ("A8,4,9"))
val toDouble: List[Point] = points.map(m => new Point(m.split(",").head, m.split(",").tail.map(m2 => m2.toDouble).toList)).toList
var initialCenters: List[Point] = util.Random.shuffle(toDouble).take(k)
val maxNumberOfIterations = 10
for (i <- 1 to maxNumberOfIterations) {
initialCenters = getNewCenters(initialCenters , toDouble : List[Point])
println(initialCenters);
}
def getNewCenters(initialCenters: List[Point] , toDouble : List[Point]) : List[Point] = {
val joined = toDouble.map(m => initialCenters.map(p => (m, p))).flatten
val grouped = joined.map(m => (m, distance(m))).grouped(k).toList
val sorted = grouped.map(m => m.sortBy(_._2).take(1)).flatten.map(m => m._1)
val mapped = sorted.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }
val vv = mapped.values.map(p => (p.map(_.points)))
try {
val averaged = mapped.values.map(p => (p.map(_.points).transpose.map(xs => xs.sum / xs.size), p.map(_.label)))
println("averaged is "+averaged.map(m => m._2))
averaged.map(m => Point("", m._1)).toList
}
catch {
case e: java.lang.IllegalArgumentException => e.printStackTrace(); println("If reading from file, ensure values no blank lines")
List[Point]()
}
}
}
Functions.scala:
package com.driver
import com.driver.KMeans.Point;
object Functions {
def distance(l1: (Point, Point)) = {
val t = l1._1.points.zip(l1._2.points)
t.map(m => Math.abs(m._1 - m._2)).sum
}
}