我有以下方法,这是执行分层k折交叉验证的逻辑的一部分。
private static IEnumerable<IEnumerable<int>> GenerateFolds(
IClassificationProblemData problemData, int numberOfFolds)
{
IRandom random = new MersenneTwister();
IEnumerable<double> values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
var valuesIndices =
problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v });
IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass =
valuesIndices.GroupBy(x => x.Value, x => x.Index)
.Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
var enumerators = foldsByClass.Select(x => x.GetEnumerator()).ToList();
while (enumerators.All(e => e.MoveNext()))
{
var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
yield return fold.ToList();
}
}
折叠代:
private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(
IEnumerable<T> values, int valuesCount, int numberOfFolds)
{
// number of folds rounded to integer and remainder
int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
int start = 0, end = f;
for (int i = 0; i < numberOfFolds; ++i)
{
if (r > 0)
{
++end;
--r;
}
yield return values.Skip(start).Take(end - start);
start = end;
end += f;
}
}
通用GenerateFolds<T
方法只是根据指定的折叠次数将IEnumerable<T>
拆分为IEnumerable
的序列。例如,如果我有101个训练样本,那么它将产生11倍大小和9倍大小10倍。
上面的方法根据类值对样本进行分组,将每个组拆分为指定的折叠数,然后将按类折叠连接到最终折叠中,确保类标签的分布相同。
我的问题是关于yield return fold.ToList()
行。实际上,该方法可以正常工作,但是如果我删除ToList()
,则结果不再正确。在我的测试用例中,我有641个训练样本和10个折叠,这意味着第一个折叠尺寸为65,剩余折叠尺寸为64.但是当我移除ToList()
时,所有折叠尺寸为64和类标签未正确分发。有什么想法吗?谢谢。
答案 0 :(得分:1)
让我们考虑什么是fold
变量:
var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
这不是查询执行的结果。这是查询定义。因为SelectMany
和OrderBy
都是具有延迟执行方式的运算符。因此,它只是保存了有关从所有枚举器展平当前项目并以随机顺序返回它们的知识。我突出显示了单词 current ,因为它是查询执行时的当前项目。
现在让我们考虑一下这个查询的执行时间。 GenerateFolds
方法执行的结果是IEnumerable
查询的IEnumerable<int>
。以下代码不执行任何查询:
var folds = GenerateFolds(indices, values, numberOfFolds);
这又是一个查询。您可以通过调用ToList()
或枚举它来执行它:
var f = folds.ToList();
但即使是现在内部查询也没有执行。它们全部归还,但未执行。即在将查询保存到列表while
时,GenerateFolds
中的f
循环已执行。并且e.MoveNext()
已被多次调用,直到您退出循环:
while (enumerators.All(e => e.MoveNext()))
{
var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
yield return fold;
}
那么,f
持有什么?它包含查询列表。因此你得到了所有这些,当前项是每个枚举器中的最后一项(记住 - 我们在这个时间点完全迭代while
循环)。但是这些查询都没有执行!在这里执行第一个:
f[0].Count()
您可以获得第一个查询返回的项目数(在问题的顶部定义)。但是因此您已经枚举了所有查询,当前项目是最后一项。并且你得到了最后一项中的索引数。
现在看看
folds.First().Count()
此处,您不会枚举所有查询以将其保存在列表中。即while
循环仅执行一次,当前项目是第一项。这就是为什么你有第一项索引的原因。这就是为什么这些价值观不同。
最后一个问题 - 当您在ToList()
循环中添加while
时,为什么一切正常。答案非常简单 - 执行每个查询。并且您有索引列表而不是查询定义。每次查询都在每次迭代时执行,因此当前项始终不同。你的代码工作正常。