将额外数据（点的标签）作为返回值的一部分返回

时间:2015-03-04 17:49:07

标签: scala k-means

以下是我的 k-means 算法实现：

package com

object Functions {

  /** Manhattan (L1) distance between the two coordinate arrays held in `l1`.
    *
    * Coordinates are paired by position (via `zip`). If the arrays differ in
    * length, the surplus trailing coordinates of the longer one are ignored.
    */
  def distance(l1: (Array[Double], Array[Double])) = {
    val (left, right) = l1
    left.zip(right).foldLeft(0.0) { case (acc, (a, b)) => acc + Math.abs(a - b) }
  }

}

package com

import com.Functions._

/** A labelled point: `label` names it (e.g. "A1"); `points` holds its coordinates. */
case class Point(
    label: String,
    points: Array[Double]
)

/** Small k-means driver: eight labelled 2-D points, three starting centres and
  * the Manhattan `distance` imported from `com.Functions`.
  *
  * The entry point is an explicit `main`. The `scala.Application` trait was
  * deprecated in Scala 2.9 and removed in 2.11.
  */
object KMeans2 {

  val points = List("A1,2,10", "A2,2,5", "A3,8,4", "A4,5,8", "A5,7,5", "A6,6,4", "A7,1,2", "A8,4,9")

  // Current centres; replaced by the newly averaged centres on every iteration of main.
  var initialCenters: Iterable[Array[Double]] =
    Iterable[Array[Double]](Array(2.0, 10.0), Array(5.0, 8.0), Array(1.0, 2.0))

  // Each "label,x,y" string parsed (split once per line) into a labelled Point.
  val toDouble: List[Point] = points.map { line =>
    val fields = line.split(",")
    Point(fields.head, fields.tail.map(_.toDouble))
  }

  // Requested cluster count. The algorithm itself works with however many
  // centres initialCenters holds, which drops below k when a cluster empties.
  val k = 3
  val maxNumberOfIterations = 10

  /** Runs the iterations. Each iteration prints the new centres, then prints
    * every cluster once. A cluster line shows the centre it was assigned to,
    * then each member's coordinates followed by its label, e.g.
    * `m is 1.0,2.0,List(2.0, 5.0, A2, 1.0, 2.0, A7)`.
    */
  def main(args: Array[String]): Unit =
    for (_ <- 1 to maxNumberOfIterations) {
      val clusters = assignClusters(initialCenters)
      initialCenters = averageClusters(clusters)

      initialCenters.foreach(center => println(center.toList))
      // The cluster listing does not depend on any single centre, so it is
      // printed once per iteration rather than once per centre.
      for ((center, members) <- clusters) {
        val described = members.flatMap(p => p.points.toList.map(_.toString) :+ p.label)
        println("m is " + center.mkString(",") + "," + described)
      }

      println("")
    }

  /** One k-means step over `toDouble`.
    *
    * @param initialCenters the current centres
    * @return a pair of (centre -> coordinates of the points assigned to it,
    *         new centres as per-cluster coordinate means). Centres that
    *         attracted no points are absent from both parts.
    */
  def getNewCenters(initialCenters: Iterable[Array[Double]]): (Map[Array[Double], List[Array[Double]]], Iterable[Array[Double]]) = {
    val clusters = assignClusters(initialCenters)
    val mapped = clusters.map { case (center, members) => (center, members.map(_.points)) }
    (mapped, averageClusters(clusters))
  }

  /** Groups the points of `toDouble` by their nearest centre, keeping the whole
    * labelled Point so callers can see which labels ended up in each cluster.
    *
    * Each point is compared against every centre individually. Ties go to the
    * centre that comes first in `centers`. Map keys are the given centre arrays
    * themselves; arrays compare by reference, so look clusters up with those
    * same instances.
    *
    * @param centers the current centres; an empty collection yields an empty map
    */
  def assignClusters(centers: Iterable[Array[Double]]): Map[Array[Double], List[Point]] =
    if (centers.isEmpty) Map.empty
    else toDouble.groupBy(p => centers.minBy(c => distance((p.points, c))))

  // Coordinate-wise mean of each (never empty) cluster: the next set of centres.
  private def averageClusters(clusters: Map[Array[Double], List[Point]]): Iterable[Array[Double]] =
    clusters.values.map { members =>
      members.map(_.points.toList).transpose.map(xs => xs.sum / xs.size).toArray
    }

}

我想修改 getNewCenters 函数，让它同时返回每个簇中各点的标签（标签为 A1 … A8）。

目前我最多只能做到返回与每个簇相关联的点的坐标，也就是返回元组 (mapped, averaged) 中的 mapped 元素。

如何同时返回标签？

也就是说，期望的输出类似下面这样（数据结构不必完全相同）：

List(1.5, 3.5)
m is 1.0,2.0,List(2.0, 5.0, A2, 1.0, 2.0, A7)
m is 2.0,10.0,List(2.0, 10.0 , A1)
m is 5.0,8.0,List(8.0, 4.0, A3, 5.0, 8.0, A4, 7.0, 5.0, A5, 6.0, 4.0, A6, 4.0, 9.0, A8)

而不是:

List(1.5, 3.5)
m is 1.0,2.0,List(2.0, 5.0, 1.0, 2.0)
m is 2.0,10.0,List(2.0, 10.0)
m is 5.0,8.0,List(8.0, 4.0, 5.0, 8.0, 7.0, 5.0, 6.0, 4.0, 4.0, 9.0)

更新:

这是我到目前为止所做的:

// Scala worksheet: one k-means step on labelled points, evaluated statement by
// statement. Lines starting with //> or //| are the worksheet's recorded output
// for the statement above them; long values are cut off by the worksheet.
object kmeans {
  println("Welcome to the Scala worksheet")       //> Welcome to the Scala worksheet

// A labelled point: label (e.g. "A1") plus its coordinates.
case class Point(label : String, points : List[Double])

    // Number of clusters; also used below as the chunk size for grouped(k).
    val k = 3                                 //> k  : Int = 3

  // Sum of absolute coordinate differences (Manhattan distance). Labels are
  // ignored; zip pairs coordinates by position and drops any surplus ones.
  def distance(l1: (Point, Point)) = {

    val t = l1._1.points.zip(l1._2.points)
    t.map(m => Math.abs(m._1 - m._2)).sum

  }                                               //> distance: (l1: (kmeans.Point, kmeans.Point))Double

  // Raw input, one "label,x,y" string per point.
  val points = List(("A1,2,10"), ("A2,2,5"), ("A3,8,4"), ("A4,5,8"), ("A5,7,5"), ("A6,6,4"), ("A7,1,2"), ("A8,4,9"))
                                                  //> points  : List[String] = List(A1,2,10, A2,2,5, A3,8,4, A4,5,8, A5,7,5, A6,6,
                                                  //| 4, A7,1,2, A8,4,9)

  // Starting centres, copied from data points A1, A4 and A7.
  var initialCenters : List[Point] = List(Point("A1",List(2, 10)), Point("A4",List(5, 8)), Point("A7",List(1, 2)))
                                                  //> initialCenters  : List[kmeans.Point] = List(Point(A1,List(2.0, 10.0)), Point
                                                  //| (A4,List(5.0, 8.0)), Point(A7,List(1.0, 2.0)))

  // Parse each "label,x,y" string into a labelled Point.
  val toDouble = points.map(m => new Point(m.split(",").head , m.split(",").tail.map(m2 => m2.toDouble).toList)).toList
                                                  //> toDouble  : List[kmeans.Point] = List(Point(A1,List(2.0, 10.0)), Point(A2,Li
                                                  //| st(2.0, 5.0)), Point(A3,List(8.0, 4.0)), Point(A4,List(5.0, 8.0)), Point(A5,
                                                  //| List(7.0, 5.0)), Point(A6,List(6.0, 4.0)), Point(A7,List(1.0, 2.0)), Point(A
                                                  //| 8,List(4.0, 9.0)))

  // Every (data point, centre) pair, ordered point by point: each data point
  // appears once per entry of initialCenters.
  val joined = toDouble.map(m => initialCenters.map(p => (m, p))).flatten
                                                  //> joined  : List[(kmeans.Point, kmeans.Point)] = List((Point(A1,List(2.0, 10.0
                                                  //| )),Point(A1,List(2.0, 10.0))), (Point(A1,List(2.0, 10.0)),Point(A4,List(5.0,
                                                  //|  8.0))), (Point(A1,List(2.0, 10.0)),Point(A7,List(1.0, 2.0))), (Point(A2,Lis
                                                  //| t(2.0, 5.0)),Point(A1,List(2.0, 10.0))), (Point(A2,Li
                                                  //| Output exceeds cutoff limit.

   // Attach each pair's distance, then cut into chunks of k pairs, meant to be
   // one chunk per data point. NOTE(review): chunks only line up with points
   // while initialCenters.size == k; with any other number of centres a chunk
   // straddles two points.
   val grouped = joined.map(m => (m, distance(m))).grouped(k).toList
                                                  //> grouped  : List[List[((kmeans.Point, kmeans.Point), Double)]] = List(List(((
                                                  //| Point(A1,List(2.0, 10.0)),Point(A1,List(2.0, 10.0))),0.0), ((Point(A1,List(2
                                                  //| .0, 10.0)),Point(A4,List(5.0, 8.0))),5.0), ((Point(A1,List(2.0, 10.0)),Point
                                                  //| (A7,List(1.0, 2.0))),9.0)), List(((Point(A2,List(2.0,
                                                  //| Output exceeds cutoff limit.
    // Keep the nearest (point, centre) pair of each chunk; sortBy is stable, so
    // on a tie the centre listed first wins.
    val sorted = grouped.map(m => m.sortBy(_._2).take(1)).flatten.map(m => m._1)
                                                  //> sorted  : List[(kmeans.Point, kmeans.Point)] = List((Point(A1,List(2.0, 10.0
                                                  //| )),Point(A1,List(2.0, 10.0))), (Point(A2,List(2.0, 5.0)),Point(A7,List(1.0, 
                                                  //| 2.0))), (Point(A3,List(8.0, 4.0)),Point(A4,List(5.0, 8.0))), (Point(A4,List(
                                                  //| 5.0, 8.0)),Point(A4,List(5.0, 8.0))), (Point(A5,List(
                                                  //| Output exceeds cutoff limit.
    // centre -> labelled Points assigned to it; the labels are available here.
    val mapped = sorted.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }
                                                  //> mapped  : scala.collection.immutable.Map[kmeans.Point,List[kmeans.Point]] = 
                                                  //| Map(Point(A4,List(5.0, 8.0)) -> List(Point(A3,List(8.0, 4.0)), Point(A4,List
                                                  //| (5.0, 8.0)), Point(A5,List(7.0, 5.0)), Point(A6,List(6.0, 4.0)), Point(A8,Li
                                                  //| st(4.0, 9.0))), Point(A7,List(1.0, 2.0)) -> List(Poin
                                                  //| Output exceeds cutoff limit.

        // New centres: coordinate-wise mean of each cluster (labels are dropped here).
        val averaged = mapped.values.map(m => m.map(m2 => m2.points).transpose.map(xs => xs.sum / xs.size))
                                                  //> averaged  : Iterable[List[Double]] = List(List(6.0, 6.0), List(1.5, 3.5), L
                                                  //| ist(2.0, 10.0))



}

1 个答案：

答案（得分：0）：

KMeans.scala:

package com.driver

import com.driver.Functions._

/** k-means driver for labelled points. It picks k random data points as the
  * starting centres, then runs a fixed number of assign-and-average iterations
  * using `Functions.distance`, printing the centres after each one.
  *
  * The entry point is an explicit `main`. The `scala.Application` trait was
  * deprecated in Scala 2.9 and removed in 2.11.
  */
object KMeans {

  /** A labelled point: `label` names it (e.g. "A3"); `points` are its coordinates. */
  case class Point(label: String, points: List[Double])

  // Number of clusters requested; used to pick the starting centres.
  val k = 3

  val points = List("A1,2,10", "A2,2,5", "A3,8,4", "A4,5,8", "A5,7,5", "A6,6,4", "A7,1,2", "A8,4,9")

  // Each "label,x,y" string parsed (split once per line) into a labelled Point.
  val toDouble: List[Point] = points.map { line =>
    val fields = line.split(",")
    Point(fields.head, fields.tail.map(_.toDouble).toList)
  }

  // k data points chosen at random as the starting centres.
  var initialCenters: List[Point] = scala.util.Random.shuffle(toDouble).take(k)

  val maxNumberOfIterations = 10

  /** Runs the iterations, printing the centres after each one. */
  def main(args: Array[String]): Unit =
    for (_ <- 1 to maxNumberOfIterations) {
      initialCenters = getNewCenters(initialCenters, toDouble)
      println(initialCenters)
    }

  /** One k-means step.
    *
    * Assigns every point to its nearest centre and prints the labels in each
    * cluster. It then returns the new centres: the coordinate-wise mean of each
    * cluster, as Points with an empty label. Centres that attract no points
    * disappear, so the result can hold fewer than k centres.
    *
    * @param initialCenters the current centres; empty yields an empty result
    * @param toDouble       the labelled data points
    * @return the new centres, or an empty list if the points' coordinate lists
    *         differ in length (the stack trace and a hint are printed)
    */
  def getNewCenters(initialCenters: List[Point], toDouble: List[Point]): List[Point] = {
    // centre -> points nearest to it. The nearest centre is picked per point
    // with minBy (the first centre wins ties), so the assignment stays correct
    // for any number of centres, including fewer than k.
    val clusters: Map[Point, List[Point]] =
      if (initialCenters.isEmpty) Map.empty
      else toDouble.groupBy(p => initialCenters.minBy(c => distance((p, c))))

    try {
      // (mean coordinates, member labels) per cluster. transpose throws
      // IllegalArgumentException if members have different numbers of coordinates.
      val averaged = clusters.values.map { members =>
        (members.map(_.points).transpose.map(xs => xs.sum / xs.size), members.map(_.label))
      }
      println("averaged is " + averaged.map(m => m._2))

      averaged.map(m => Point("", m._1)).toList
    } catch {
      case e: java.lang.IllegalArgumentException =>
        e.printStackTrace()
        println("If reading from file, ensure values no blank lines")
        List[Point]()
    }
  }

}

Functions.scala:

package com.driver

import com.driver.KMeans.Point;

object Functions {

  /** Manhattan (L1) distance between the coordinates of the two points in `l1`.
    *
    * Labels play no part. Coordinates are paired by position (via `zip`), so
    * surplus coordinates on the point with more of them are ignored.
    */
  def distance(l1: (Point, Point)) = {
    val (from, to) = l1
    from.points.zip(to.points).foldLeft(0.0) { case (acc, (a, b)) => acc + Math.abs(a - b) }
  }


}