如果在张量流中的mnist上有一个简单的小批量梯度下降问题(如此tutorial),我该如何单独检索批处理中每个示例的渐变。
tf.gradients()
似乎返回批处理中所有示例的平均梯度。有没有办法在聚合之前检索渐变?
编辑:这个答案的第一步是弄清楚张力流在哪个点上平均了批次中的例子的梯度。我认为这发生在_AggregatedGrads,但事实并非如此。有什么想法吗?
答案 0 :(得分:7)
tf.gradients
返回与损失相关的渐变。这意味着如果您的损失是每个示例损失的总和,那么梯度也是每个示例损失梯度的总和。
总结是隐含的。例如,如果您希望最小化Wx-y
误差的平方范围总和,则W
的梯度为2(WX-Y)X'
,其中X
是一组观察值Y
1}}是一批标签。你永远不会明确地形成"每个例子"您稍后总结的渐变,因此在渐变管道中删除某个阶段并不是一件简单的事情。
获得k
每个示例损失渐变的简单方法是使用大小为1的批次并执行k
次传递。 Ian Goodfellow wrote up如何在一次传递中获得所有k
渐变,为此您需要明确指定渐变而不依赖tf.gradients
方法
答案 1 :(得分:1)
在修补一段时间后,部分回答我自己的问题。看来,通过执行以下操作,可以在批量工作的同时操作渐变:
custagg_gradients(
ys=[cross_entropy[i] for i in xrange(batch_size)],
xs=variables.trainable_variables(),
aggregation_method=CUSTOM,
gradient_factors=gradient_factors
)
但是这可能与每个示例的单个传递具有相同的复杂性,我需要检查渐变是否正确: - )。
答案 2 :(得分:0)
在聚合之前检索渐变的一种方法是使用grads_ys
参数。这里有一个很好的讨论:
Use of grads_ys parameter in tf.gradients - TensorFlow
编辑:
我最近没有和Tensorflow合作过,但这是一个开放的问题,跟踪计算非聚合渐变的最佳方法:
https://github.com/tensorflow/tensorflow/issues/675
用户(包括我自己)提供了大量示例代码解决方案,您可以根据自己的需要进行尝试。