张量流中的tf.GraphKeys.GLOBAL_VARIABLES和tf.GraphKeys.TRAINABLE_VARIABLES有什么区别?

时间:2019-09-25 11:29:09

标签: python tensorflow

来自https://www.tensorflow.org/api_docs/python/tf/GraphKeys

  

GLOBAL_VARIABLES:共享的Variable对象的默认集合   跨分布式环境(模型变量是这些变量的子集)。   有关更多详细信息,请参见tf.compat.v1.global_variables。通常,所有   TRAINABLE_VARIABLES变量将位于MODEL_VARIABLES中,所有   MODEL_VARIABLES变量将位于GLOBAL_VARIABLES

     

TRAINABLE_VARIABLES:将是可变对象的子集   由优化师培训。有关更多信息,请参见tf.compat.v1.trainable_variables   详细信息。

据我所知TRAINABLE_VARIABLESGLOBAL_VARIABLES的子集,那么GLOBAL_VARIABLES还包含什么?

对于这个简单的示例语句Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES也不成立:

IMAGE_HEIGHT = 5
IMAGE_WIDTH = 5
with tf.Graph().as_default():
    with tf.variable_scope('my_scope', reuse=tf.AUTO_REUSE):
        x_ph = tf.placeholder(
                dtype=tf.float32,
                shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                name='input'
            )

        x_tf = tf.layers.conv2d(x_ph, 32, 1, 1, padding='valid')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        x_np = np.random.rand(1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)

        out_np = sess.run(x_tf, {x_ph:x_np})

        print('out_np.shape', out_np.shape)

        print('-'*60)
        global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        print('len(global_vars)', len(global_vars))
        print('global_vars params:', sum([np.prod(var.shape) for var in global_vars]))
        print(global_vars)

        print('-'*60)
        model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
        print('len(model_vars)', len(model_vars))
        print('model_vars params:', sum([np.prod(var.shape) for var in model_vars]))
        print(model_vars)

        print('-'*60)
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print('len(trainable_vars)', len(trainable_vars))
        print('trainable_vars params:', sum([np.prod(var.shape) for var in trainable_vars]))
        print(trainable_vars)

输出:

out_np.shape (1, 5, 5, 32)
------------------------------------------------------------
len(global_vars) 2
global_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
------------------------------------------------------------
len(model_vars) 0
model_vars params: 0
[]
------------------------------------------------------------
len(trainable_vars) 2
trainable_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]

所以问题是:

  1. 为什么Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES不适合本示例。

  2. GLOBAL_VARIABLES之外,TRAINABLE_VARIABLES还包含哪些其他变量?是TRAINABLE_VARIABLES始终是GLOBAL_VARIABLES的子集,还是它们可以部分相交?

1 个答案:

答案 0 :(得分:2)

注意:所有这些仅适用于TF版本 1 ,因为所有变量集合均已弃用,并且(IIRC)不在TF v2中。 >

从问题2开始:

  

除了TRAINABLE_VARIABLES之外,GLOBAL_VARIABLES还包含哪些其他变量?

例如,

global_step是不可训练的全局变量。 这是一个变量,因为您在每个步骤都进行了更新,因此它不是可训练的,因为它不是优化过程的一部分(例如,不是为了最小化损失而改变的权重/偏差)。

  

是真的TRAINABLE_VARIABLES将始终是GLOBAL_VARIABLES的子集,还是它们只能部分相交?

原则上,两组可能只是部分相交,尽管这很奇怪。我能想到的一个例子是自定义的分布式培训环境,其中每台机器都有自己的优化器,并且某些可训练变量定义为局部变量(即,每台机器都有自己的副本,而这些副本未保存在同步)。为什么要这么做?没有线索。但从原则上讲,这是可能的。

然后关于问题1:

我相信您引用的语句缺少重要的说明:需要将变量放入MODEL_VARIABLES集合中,默认情况下变量仅添加到GLOBAL_VARIABLES中收集,如果trainable=True也收集到TRAINABLE_VARIABLS。 TF本身无法知道哪些变量是推理所必需的,而哪些变量仅是用于训练的(例如,带有辅助头的网络仅用于训练),因此它留给了网络架构师。附带说明,我从未见过该收藏集在任何地方都可以使用,而且我相信它目前尚未使用。