Scala:k折叠交叉验证

时间:2016-02-24 08:33:02

标签: scala

我想进行k折交叉验证。基本上我们给了一堆数据allData。假设我们将输入划分为" k"群集并将其放入groups

所需的输出为trainAndTestDataList: List[(Iterable[T], Iterable[T])],其中List的大小为" k"。 trainAndTestDataList的" i"元素是像(A,B)这样的元组,其中A应该是groups的" i"元素和B应该是groups的所有元素,除了" i" th,连接。

有效实施此方法的任何想法?

val allData: Iterable[T] = ... // we get the data from somewhere 

val groupSize = Math.ceil(allData.size / k).toInt
val groups = allData.grouped(groupSize).toList

val trainAndTestDataList = ... // fill out this part 

要记住的一件事是allData可能会很长,但是" k"非常小(比如5)。因此,将所有数据向量保持为Iterator(而非ListSeq等)非常重要。

更新:这是我的方式(我对此并不满意):

val trainAndTestDataList = {
  (0 until k).map{ fold =>
    val (a,b) = groups.zipWithIndex.partition{case (g, idx) => idx == fold}
    (a.unzip._1.flatten.toIterable, b.unzip._1.flatten.toIterable)
  }
}

我不喜欢它的原因:

  1. 特别是在我partition unzip._1flatten a后扭曲。我认为一个人应该能够做得更好。
  2. 虽然Iterable[T]a.unzip._1.flatten.,但我认为List[T]的输出是{{1}}。这不好,因为此列表中元素的数量可能非常大。

2 个答案:

答案 0 :(得分:2)

您可以尝试该操作

\1

此分组与implicit class TeeSplitOp[T](data: Iterable[T]) { def teeSplit(count: Int): Stream[(Iterable[T], Iterable[T])] = { val size = data.size def piece(i: Int) = i * size / count Stream.range(0, size - 1) map { i => val (prefix, rest) = data.splitAt(piece(i)) val (test, postfix) = rest.splitAt(piece(i + 1) - piece(i)) val train = prefix ++ postfix (test, train) } } } 一样懒惰,splitAt在您的集合类型中。

您可以尝试

++

答案 1 :(得分:0)

我相信这应该有效。它还以合理有效的方式处理随机化(不要忽略这一点!),即使用随机shuffle /置换的更天真的方法所需的O(n)而不是O(n log(n))数据。

import scala.util.Random

def testTrainDataList[T](
  data: Seq[T], 
  k: Int, 
  seed: Long = System.currentTimeMillis()
): Seq[(Iterable[T], Iterable[T])] = {
  def createKeys(n: Int, k: Int) = {
    val groupSize = n/k
    val rem = n % k
    val cumCounts = Array.tabulate(k){ i =>
      if (i < rem) (i + 1)*(groupSize + 1) else (i + 1)*groupSize + rem
    }
    val rng = new Random(seed)
    for (count <- n to 1 by -1) yield {
      val j = rng.nextInt(count)
      val i = cumCounts.iterator.zipWithIndex.find(_._1 > j).map(_._2).get
      for (s <- i until k) cumCounts(s) -= 1
    }
  }

  val keys = createKeys(data.length, k)
  for (i <- 0 until k) yield {
    val testIterable = new Iterable[T] {
      def iterator = (keys.iterator zip data.iterator).filter(_._1 == i).map(_._2)
    }
    val trainIterable = new Iterable[T] {
      def iterator = (keys.iterator zip data.iterator).filter(_._1 != i).map(_._2)
    }
    (testIterator, trainIterator)
  }
}

请注意我定义testIterabletrainIterable的方式。这使得你的测试/训练设置了懒惰和非记忆,我收集的是你想要的。

使用示例:

val data = 'a' to 'z'
for (((testData, trainData), index) <- testTrainDataList(data, 4).zipWithIndex) {
  println(s"index = $index")
  println("test: " + testData.mkString(", "))
  println("train: " + trainData.mkString(", "))
}

//index = 0
//test: i, l, o, q, v, w, y
//train: a, b, c, d, e, f, g, h, j, k, m, n, p, r, s, t, u, x, z
//
//index = 1
//test: a, d, e, h, n, r, z
//train: b, c, f, g, i, j, k, l, m, o, p, q, s, t, u, v, w, x, y
//
//index = 2
//test: b, c, m, t, u, x
//train: a, d, e, f, g, h, i, j, k, l, n, o, p, q, r, s, v, w, y, z
//
//index = 3
//test: f, g, j, k, p, s
//train: a, b, c, d, e, h, i, l, m, n, o, q, r, t, u, v, w, x, y, z