获取与TensorFlow中特定张量相关的所有tf.Variables

时间:2016-06-11 15:01:03

标签: python tensorflow

我想创建一个实用程序函数,在其中我给它一个张量t和一个变量名n,函数返回(它存在)variable名称包含{ {1}}是n

图表的一部分
t

我想这样做的原因是使用Jupyter我经常在决定最佳结构时多次创建图形,但是张量名称不断改变为:def get_variable(t, n): #code return variable ,所以通过{{1}找到它们变得越来越难了。

如果我可以将搜索限制为与特定张量相关的变量会更容易,因为它的标识符是最新的重复。

1 个答案:

答案 0 :(得分:2)

(警告:我不打算回答有关创建实用程序功能的主要问题)

张量的名称如下:{name}_{repetition}:0不是因为您一次又一次地将这些张量添加到默认图

解决方案是始终指定使用graph.as_default()的图表:

graph = tf.Graph()
with graph.as_default():
  with tf.variable_scope('foo'):
    var = tf.get_variable('var', [])
    print var.name  # should be 'foo/var:0'
    tensor = tf.constant(2., name='tensor')
    print tensor.name  # should be 'foo/tensor:0'

如果再次运行,您会看到完全相同的结果,因为graph = tf.Graph()行会创建一个新图表。

graph = tf.Graph()
with graph.as_default():
  with tf.variable_scope('foo'):
    var = tf.get_variable('var', [])
    print var.name  # should be 'foo/var:0'
    tensor = tf.constant(2., name='tensor')
    print tensor.name  # should be 'foo/tensor:0'

缺点是现在你不能依赖默认图并且应该通过图来计算所有变量的列表:

print tf.all_variables()  # returns []

with graph.as_default():
  print tf.all_variables()  # returns [...] with all variables