张量流中的局部变量是什么?

时间:2016-08-12 04:54:06

标签: tensorflow

Tensorflow已定义此API:

  

tf.local_variables()

     

返回使用collection=[LOCAL_VARIABLES]创建的所有变量。

     

返回:

     

本地变量对象的列表。

TensorFlow中的局部变量究竟是什么?有人可以举个例子吗?

3 个答案:

答案 0 :(得分:21)

简短回答:TF中的局部变量是使用collections=[tf.GraphKeys.LOCAL_VARIABLES]创建的任何变量。例如:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
  

LOCAL_VARIABLES:每个都是本地的Variable对象的子集   机。通常用于临时变量,如计数器。注意:   使用tf.contrib.framework.local_variable添加到此集合。

它们通常不会保存/恢复到检查点,并用于临时或中间值。

答案很长:这对我来说也是一个混乱的根源。在开始时我认为局部变量与local variable in almost any programming language的意思相同,但它不是一回事:

import tensorflow as tf

def some_func():
    z = tf.Variable(1, name='var_z')

a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)
with tf.name_scope('aaa'):
    c = tf.Variable(3, name='var_c')

with tf.variable_scope('bbb'):
    d = tf.Variable(3, name='var_d')

some_func()
some_func()

print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]

无论我尝试什么,我总是只收到全局变量:

['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]

tf.local_variables的文档没有提供很多细节:

  

局部变量 - 每个过程变量,通常不保存/恢复   检查点并用于临时或中间值。例如,   它们可以用作度量计算的计数器或数量   epochs这台机器已读取数据。 local_variable()自动   向GraphKeys.LOCAL_VARIABLES添加新变量。这方便   function返回该集合的内容。

但是在tf.Variable类中读取init方法的文档时,我发现在创建变量时,您可以通过分配{{1}列表来提供您想要的变量类型。 }}

可能的集合元素列表是here。所以要创建一个局部变量,你需要做这样的事情。您将在collections

列表中看到它
local_variables

答案 1 :(得分:18)

它与常规变量相同,但它与默认值(var containerName = "mycontainer"; // remove /myvd and SAS will work fine var containerReference = blobClient.GetContainerReference( containerName ); var blobName = "myvd/readme.txt"; //Your blob name is actually "myvd/readme.txt" var blobReference = await containerReference.GetBlobReferenceFromServerAsync( blobName ); )不同。该保护程序使用该集合来初始化要保存的默认变量列表,因此具有[DebuggerDisplay("ID:{ID},Customers:{Customers==null?(int?)null:Customers.Count}")]` class Project { int ID{get;set;} IList<Customer> Customers{get;set;} } 指定具有默认情况下不保存该变量的效果。

我在代码库中只看到一个使用它的地方,即GraphKeys.VARIABLES

local

答案 2 :(得分:5)

我认为,这里需要了解TensorFlow集合。

TensorFlow提供了集合,这些集合被称为张量或其他对象(例如tf.Variable实例)的列表。

以下是内置集合:

tf.GraphKeys.GLOBAL_VARIABLES               #=> 'variables'                                                                                                                                                                                 
tf.GraphKeys.LOCAL_VARIABLES                #=> 'local_variables'                                                                                                                                                                           
tf.GraphKeys.MODEL_VARIABLES                #=> 'model_variables'                                                                                                                                                                           
tf.GraphKeys.TRAINABLE_VARIABLES            #=> 'trainable_variables' 

通常,在创建变量时,可以通过将其作为传递给collections参数的集合之一来显式传递该集合,从而将其添加到给定的集合中。

从理论上讲,变量可以是内置或自定义集合的任意组合。但是,内置集合用于特定目的:

  • tf.GraphKeys.GLOBAL_VARIABLES
    • Variable()构造函数或get_variable()会自动将新变量添加到图集合GraphKeys.GLOBAL_VARIABLES中,除非显式传递了collections自变量且其中不包含{{1} }。
    • 按照惯例,这些变量在分布式环境之间共享(模型变量是这些变量的子集)。
    • 有关更多详细信息,请参见tf.global_variables()
  • tf.GraphKeys.TRAINABLE_VARIABLES
    • 在传递GLOBAL_VARIABLE(这是默认行为)时,trainable=True构造函数和Variable()自动向该图形集合添加新变量。但是,当然,您可以使用get_variable()参数来添加变量 转到任何所需的集合。
    • 按照惯例,这些是将由优化器训练的变量。
    • 有关更多详细信息,请参见tf.trainable_variables()
  • tf.GraphKeys.LOCAL_VARIABLES
    • 您可以使用tf.contrib.framework.local_variable()添加到此收藏集。但是当然,您可以使用collections参数将变量添加到 任何所需的集合。
    • 按照惯例,这些是每台计算机本地的变量。它们是每个过程变量,通常不保存/恢复到检查点,而是用于临时或中间值。例如,它们可用作计数器 用于度量计算或本机已读取数据的时期数。
    • 有关更多详细信息,请参见tf.local_variables()
  • tf.GraphKeys.MODEL_VARIABLES
    • 您可以使用tf.contrib.framework.model_variable()添加到此收藏集。但是当然,您可以使用collections参数将变量添加到 任何所需的集合。
    • 按照惯例,这些是模型中用于推理(前馈)的变量。
    • 有关更多详细信息,请参见tf.model_variables()

您还可以使用自己的收藏集。任何字符串都是有效的集合名称,无需显式创建集合。要在创建变量后将变量(或任何其他对象)添加到集合中,请调用tf.add_to_collection()

例如

collections
相关问题