我想进行k折交叉验证。基本上我们给了一堆数据allData
。假设我们将输入划分为" k"群集并将其放入groups
。
所需的输出为trainAndTestDataList: List[(Iterable[T], Iterable[T])]
,其中List
的大小为" k"。 trainAndTestDataList
的" i"元素是像(A,B)这样的元组,其中A应该是groups
的" i"元素和B应该是groups
的所有元素,除了" i" th,连接。
有效实施此方法的任何想法?
val allData: Iterable[T] = ... // we get the data from somewhere
val groupSize = Math.ceil(allData.size / k).toInt
val groups = allData.grouped(groupSize).toList
val trainAndTestDataList = ... // fill out this part
要记住的一件事是allData
可能会很长,但是" k"非常小(比如5)。因此,将所有数据向量保持为Iterator
(而非List
,Seq
等)非常重要。
更新:这是我的方式(我对此并不满意):
val trainAndTestDataList = {
(0 until k).map{ fold =>
val (a,b) = groups.zipWithIndex.partition{case (g, idx) => idx == fold}
(a.unzip._1.flatten.toIterable, b.unzip._1.flatten.toIterable)
}
}
我不喜欢它的原因:
partition
unzip
,._1
和flatten
a
后扭曲。我认为一个人应该能够做得更好。 Iterable[T]
是a.unzip._1.flatten.
,但我认为List[T]
的输出是{{1}}。这不好,因为此列表中元素的数量可能非常大。 答案 0 :(得分:2)
您可以尝试该操作
\1
此分组与implicit class TeeSplitOp[T](data: Iterable[T]) {
def teeSplit(count: Int): Stream[(Iterable[T], Iterable[T])] = {
val size = data.size
def piece(i: Int) = i * size / count
Stream.range(0, size - 1) map { i =>
val (prefix, rest) = data.splitAt(piece(i))
val (test, postfix) = rest.splitAt(piece(i + 1) - piece(i))
val train = prefix ++ postfix
(test, train)
}
}
}
一样懒惰,splitAt
在您的集合类型中。
您可以尝试
++
答案 1 :(得分:0)
我相信这应该有效。它还以合理有效的方式处理随机化(不要忽略这一点!),即使用随机shuffle /置换的更天真的方法所需的O(n)而不是O(n log(n))数据。
import scala.util.Random
def testTrainDataList[T](
data: Seq[T],
k: Int,
seed: Long = System.currentTimeMillis()
): Seq[(Iterable[T], Iterable[T])] = {
def createKeys(n: Int, k: Int) = {
val groupSize = n/k
val rem = n % k
val cumCounts = Array.tabulate(k){ i =>
if (i < rem) (i + 1)*(groupSize + 1) else (i + 1)*groupSize + rem
}
val rng = new Random(seed)
for (count <- n to 1 by -1) yield {
val j = rng.nextInt(count)
val i = cumCounts.iterator.zipWithIndex.find(_._1 > j).map(_._2).get
for (s <- i until k) cumCounts(s) -= 1
}
}
val keys = createKeys(data.length, k)
for (i <- 0 until k) yield {
val testIterable = new Iterable[T] {
def iterator = (keys.iterator zip data.iterator).filter(_._1 == i).map(_._2)
}
val trainIterable = new Iterable[T] {
def iterator = (keys.iterator zip data.iterator).filter(_._1 != i).map(_._2)
}
(testIterator, trainIterator)
}
}
请注意我定义testIterable
和trainIterable
的方式。这使得你的测试/训练设置了懒惰和非记忆,我收集的是你想要的。
使用示例:
val data = 'a' to 'z'
for (((testData, trainData), index) <- testTrainDataList(data, 4).zipWithIndex) {
println(s"index = $index")
println("test: " + testData.mkString(", "))
println("train: " + trainData.mkString(", "))
}
//index = 0
//test: i, l, o, q, v, w, y
//train: a, b, c, d, e, f, g, h, j, k, m, n, p, r, s, t, u, x, z
//
//index = 1
//test: a, d, e, h, n, r, z
//train: b, c, f, g, i, j, k, l, m, o, p, q, s, t, u, v, w, x, y
//
//index = 2
//test: b, c, m, t, u, x
//train: a, d, e, f, g, h, i, j, k, l, n, o, p, q, r, s, v, w, y, z
//
//index = 3
//test: f, g, j, k, p, s
//train: a, b, c, d, e, h, i, l, m, n, o, q, r, t, u, v, w, x, y, z