来自https://www.tensorflow.org/api_docs/python/tf/GraphKeys
GLOBAL_VARIABLES:共享的Variable对象的默认集合 跨分布式环境(模型变量是这些变量的子集)。 有关更多详细信息,请参见tf.compat.v1.global_variables。通常,所有 TRAINABLE_VARIABLES变量将位于MODEL_VARIABLES中,所有 MODEL_VARIABLES变量将位于GLOBAL_VARIABLES
TRAINABLE_VARIABLES:将是可变对象的子集 由优化师培训。有关更多信息,请参见tf.compat.v1.trainable_variables 详细信息。
据我所知TRAINABLE_VARIABLES
是GLOBAL_VARIABLES
的子集,那么GLOBAL_VARIABLES
还包含什么?
对于这个简单的示例语句Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES
也不成立:
IMAGE_HEIGHT = 5
IMAGE_WIDTH = 5
with tf.Graph().as_default():
with tf.variable_scope('my_scope', reuse=tf.AUTO_REUSE):
x_ph = tf.placeholder(
dtype=tf.float32,
shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
name='input'
)
x_tf = tf.layers.conv2d(x_ph, 32, 1, 1, padding='valid')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
x_np = np.random.rand(1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
out_np = sess.run(x_tf, {x_ph:x_np})
print('out_np.shape', out_np.shape)
print('-'*60)
global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print('len(global_vars)', len(global_vars))
print('global_vars params:', sum([np.prod(var.shape) for var in global_vars]))
print(global_vars)
print('-'*60)
model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
print('len(model_vars)', len(model_vars))
print('model_vars params:', sum([np.prod(var.shape) for var in model_vars]))
print(model_vars)
print('-'*60)
trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print('len(trainable_vars)', len(trainable_vars))
print('trainable_vars params:', sum([np.prod(var.shape) for var in trainable_vars]))
print(trainable_vars)
输出:
out_np.shape (1, 5, 5, 32)
------------------------------------------------------------
len(global_vars) 2
global_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
------------------------------------------------------------
len(model_vars) 0
model_vars params: 0
[]
------------------------------------------------------------
len(trainable_vars) 2
trainable_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
所以问题是:
为什么Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES
不适合本示例。
除GLOBAL_VARIABLES
之外,TRAINABLE_VARIABLES
还包含哪些其他变量?是TRAINABLE_VARIABLES
始终是GLOBAL_VARIABLES
的子集,还是它们可以部分相交?
答案 0 :(得分:2)
注意:所有这些仅适用于TF版本 1 ,因为所有变量集合均已弃用,并且(IIRC)不在TF v2中。 >
例如,除了TRAINABLE_VARIABLES之外,GLOBAL_VARIABLES还包含哪些其他变量?
global_step
是不可训练的全局变量。
这是一个变量,因为您在每个步骤都进行了更新,因此它不是可训练的,因为它不是优化过程的一部分(例如,不是为了最小化损失而改变的权重/偏差)。
是真的TRAINABLE_VARIABLES将始终是GLOBAL_VARIABLES的子集,还是它们只能部分相交?
原则上,两组可能只是部分相交,尽管这很奇怪。我能想到的一个例子是自定义的分布式培训环境,其中每台机器都有自己的优化器,并且某些可训练变量定义为局部变量(即,每台机器都有自己的副本,而这些副本未保存在同步)。为什么要这么做?没有线索。但从原则上讲,这是可能的。
我相信您引用的语句缺少重要的说明:您需要将变量放入MODEL_VARIABLES
集合中,默认情况下变量仅添加到GLOBAL_VARIABLES
中收集,如果trainable=True
也收集到TRAINABLE_VARIABLS
。 TF本身无法知道哪些变量是推理所必需的,而哪些变量仅是用于训练的(例如,带有辅助头的网络仅用于训练),因此它留给了网络架构师。附带说明,我从未见过该收藏集在任何地方都可以使用,而且我相信它目前尚未使用。