培训期间如何在Keras中使用optimizers.SGD的get_updates()?

时间:2019-03-08 07:25:36

标签: tensorflow keras

我不熟悉Keras的内部工作原理,很难理解Keras在培训过程中如何使用Optimizer.SGD的get_updates()功能。

我在互联网上搜索了很长时间,但只得到了很少的细节。具体来说,我的理解是SGD的参数/权重更新规则是在get_updates()函数中定义的。但是看来get_updates()并不是在训练期间的每次迭代中都被字面意思调用;否则,“时刻”将不会从一个迭代到下一个迭代进行以正确实现动量,因为在每次调用c.f中都会重置该动量。 optimizers.py:

shapes = [K.get_variable_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + moments
for p, g, m in zip(params, grads, moments):
    v = self.momentum * m - lr * g  # velocity
    self.updates.append(K.update(m, v))

https://github.com/keras-team/keras/issues/7502中所指出,get_updates()仅定义“符号计算图”。我不确定那是什么意思。有人可以给出更详细的解释吗?

例如,在一次迭代中计算出的“ v”如何在下一次迭代中传递给“矩”以实现动量?如果有人可以向我介绍有关此工作原理的教程,我也将不胜感激。

非常感谢! (顺便说一句,如果有关系,我正在使用张量流。)

1 个答案:

答案 0 :(得分:1)

get_updates()定义用于更新渐变的图形操作。 当评估图形进行训练时,它将看起来像这样:

  • 前向通过计算预测值
  • 损失计算成本
  • 向后传递计算梯度
  • 渐变已更新

更新渐变本身就是一个图形计算;即您引用的代码片段通过指定所涉及的张量以及发生了哪些数学运算来定义如何执行运算。此时,数学运算本身尚未发生。

moments是上面代码中定义的张量向量。该代码创建了一个图形操作,用于更新每个moments元素。

图形的每次迭代都会运行此更新操作。

以下链接试图解释TensorFlow中计算图的概念: https://www.tensorflow.org/guide/graphs

Keras使用相同的基本思想,但使用户不必处理底层细节。在传统的TensorFlow 1.0 API中定义模型需要更高的详细程度。