张量流中tf.GraphKeys.TRAINABLE_VARIABLES和tf.GraphKeys.UPDATE_OPS之间的区别是什么?

时间:2018-01-15 09:47:32

标签: tensorflow

这是张量流中tf.GraphKeys的doc,例如TRAINABLE_VARIABLES:将由优化器训练的Variable对象的子集。

我知道tf.get_collection(),可以找到你想要的张量。

使用tensorflow.contrib.layers.batch_norm()时,参数updates_collections默认值为GraphKeys.UPDATE_OPS

我们如何理解这些集合及其差异。

此外,我们可以在ops.py中找到更多信息。

2 个答案:

答案 0 :(得分:15)

这是两件不同的事情。

TRAINABLE_VARIABLES

TRAINABLE_VARIABLES变量或训练参数的集合,应在最小化损失时进行修改。例如,这些可以是确定由网络中的每个节点执行的功能的权重。

如何将变量添加到此集合中?当您使用tf.get_variable定义新变量时会自动发生这种情况,除非您指定

tf.get_variable(..., trainable=False)

您希望变量什么时候无法解决?这种情况时有发生。例如,有时您会想要使用两步法,首先在大型通用数据集上训练整个网络,然后在与您的问题特别相关的较小数据集上微调网络。在这种情况下,您可能只想微调网络的一部分,例如最后一层。将一些变量指定为无法解释是实现此目的的一种方法。

UPDATE_OPS

UPDATE_OPS ops 的集合(图表运行时执行的操作,如乘法,ReLU等),而不是变量。具体来说,此集合维护一个需要在每个培训步骤之前运行的操作列表。

如何将操作添加到此集合中? 根据定义,update_ops通过损失最小化发生在常规训练流程之外,因此通常只有在特殊情况下才会将ops添加到此集合中。例如,在执行批量标准化时,您希望在每个培训步骤之前重新计算批处理均值和方差,这就是它的完成方式。使用tf.contrib.layers.batch_norm的批量规范化机制在this article中有更详细的描述。

答案 1 :(得分:0)

不同意上一个答案。
实际上,所有内容都是tensorflow中的OP,TRAINABLE_VARIABLES集合中的变量也是OPs,由OP tf.get_variabletf.Variable创建。

对于UPDATE_OPS集合,通常包括在tf.layers.batch_norm函数中创建的移动平均值和移动方差。这些操作也可以视为变量,因为它们的值会在每个训练步骤中更新,就像权重和偏差一样。

主要区别在于trainable变量参与后退propagation的过程,而UPDATE_OPS中的变量不参与。它们仅在测试模式下参与推理过程,因此会在UPDATE_OPS中的这些变量上计算出网格。