我在Apache Spark代码源
中遇到了this lineval (gradientSum, lossSum, miniBatchSize) = data
.sample(false, miniBatchFraction, 42 + i)
.treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
seqOp = (c, v) => {
// c: (grad, loss, count), v: (label, features)
val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
(c._1, c._2 + l, c._3 + 1)
},
combOp = (c1, c2) => {
// c: (grad, loss, count)
(c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
}
)
我在读这篇文章时遇到了一些麻烦:
treeAggregate
工作原理的内容,这些参数的含义是什么。 .treeAggregate
似乎在方法名称后面有两个()()。这意味着什么?这是一些我不理解的特殊scala语法。 这句话必须非常先进。我无法解读这个。
答案 0 :(得分:16)
treeAggregate
是aggregate
的专用实现,它将combine函数迭代地应用于分区子集。这样做是为了防止将所有部分结果返回给驱动程序,其中单个传递减少将像经典aggregate
那样发生。
出于所有实际目的,treeAggregate
遵循与此答案中解释的aggregate
相同的原则:Explain the aggregate functionality in Python,但需要额外的参数来指示部分聚合的深度水平。
让我试着解释一下这里发生了什么:
对于聚合,我们需要零,组合器函数和reduce函数。
aggregate
使用currying独立于combine和reduce函数指定零值。
然后我们可以像这样剖析上述功能。希望这有助于理解:
val Zero: (BDV, Double, Long) = (BDV.zeros[Double](n), 0.0, 0L)
val combinerFunction: ((BDV, Double, Long), (??, ??)) => (BDV, Double, Long) = (c, v) => {
// c: (grad, loss, count), v: (label, features)
val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
(c._1, c._2 + l, c._3 + 1)
val reducerFunction: ((BDV, Double, Long),(BDV, Double, Long)) => (BDV, Double, Long) = (c1, c2) => {
// c: (grad, loss, count)
(c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
}
然后我们可以用更易消化的形式重写对treeAggregate
的调用:
val (gradientSum, lossSum, miniBatchSize) = treeAggregate(Zero)(combinerFunction, reducerFunction)
此表单会将生成的元组“提取”到指定值gradientSum, lossSum, miniBatchSize
中以供进一步使用。
请注意,treeAggregate
需要使用默认值depth = 2
声明的附加参数depth
,因此,在此特定调用中未提供该参数,它将采用该默认值。