Tensorflow:按名称获取所有重量张量

时间:2018-05-03 15:10:35

标签: python tensorflow

我加载图表并想要访问图表中定义的权重h1h2h3

我可以通过以下方式轻松地为每个重量张量h执行此操作:

sess = tf.Session()
graph = tf.get_default_graph()
h1 = sess.graph.get_tensor_by_name("h1:0")
h2 = sess.graph.get_tensor_by_name("h2:0")

我不喜欢这种方法,因为对于大图来说它会很难看。我更喜欢像所有权重张量的循环一样将它们放入列表中。

我在Stack Overflow上找到了另外两个问题(herehere),但是他们没有帮我解决这个问题。

我尝试了以下方法,它有两个问题:

num_weight_tensors = 3
weights = []
for w in range(num_weight_tensors):
    weights.append(sess.graph.get_tensor_by_name("h1:0"))
print(weights)

第一个问题:我必须在图表中定义权重张量的数量,这使得代码不灵活。第二个问题:get_tensor_by_name()的论点是静态的。

有没有办法获得所有张量并将它们放入列表?

2 个答案:

答案 0 :(得分:1)

如果您只关心可以优化的权重,可以致电let listItems = document.querySelectorAll('input[type="hidden"] RID'); and let listItems = document.querySelectorAll('input[type="hidden"] #RID'); 。它返回tf.trainable_variables()参数设置为trainable的所有变量的列表。

True

打印:

tf.reset_default_graph()

# These can be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))

# These cannot be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="n{}".format(i), trainable=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = tf.get_default_graph()
    for t_var in tf.trainable_variables():
        print(t_var)

另一方面,<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref> <tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref> <tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref> <tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref> <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref> 返回所有变量的列表:

tf.global_variables()
for g_var in tf.global_variables():
    print(g_var)

<强>更新

为了更好地控制您希望收到的变量,有几种方法可以过滤它们。一种方法是openmark建议的方式。在这种情况下,您可以根据变量范围前缀过滤它们。

但是,如果这还不够,例如,如果您希望同时访问多个组,则还有其他方法。您只需按名称过滤它们,即:

<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n4:0' shape=(32, 32) dtype=float32_ref>

但是,您必须了解tensorflow变量的命名约定。那就是for g_var in tf.global_variables(): if g_var.name.beginswith('h'): print(g_var) 后缀,例如,变量范围前缀等等。

第二种方式,更少涉及,是创建自己的集合。例如,如果我对以2整除的数字结尾的变量以及代码中的其他地方感兴趣,我感兴趣的是名称以可被4整除的数字结尾的变量,我可以这样做:

:0

然后我可以简单地调用# These can be optimized for i in range(5): h_var = tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i)) if i % 2 == 0: tf.add_to_collection('vars_divisible_by_2', h_var) if i % 4 == 0: tf.add_to_collection('vars_divisible_by_4', h_var) 函数:

tf.get_collection()
tf.get_collection('vars_divisible_by_2)

[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]
tf.get_collection('vars_divisible_by_4'):

答案 1 :(得分:1)

您可以尝试tf.get_collection()

tf.get_collection(
key,
scope=None)

它返回keyscope指定的集合中的项目列表。 key是标准图集tf.GraphKeys中的密钥,例如,tf.GraphKeys.TRAINABLE_VARIABLES指定由优化程序训练的变量子集,而tf.GraphKeys.GLOBAL_VARIABLES指定全局变量列表,包括不可训练的。检查上面的链接以获取可用密钥类型的列表。您还可以指定scope参数来过滤结果列表以仅返回特定名称范围中的项目,这是一个小示例:

with tf.name_scope("aaa"):
    aaa1 = tf.Variable(tf.zeros(shape=(1,2,3)), name="aaa1")


with tf.name_scope("bbb"):
    bbb1 = tf.Variable(tf.zeros(shape=(4,5,6)), name="bbb1", trainable=False)

for item in  tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES):
    print(item)
# >>> <tf.Variable 'aaa/aaa1:0' shape=(1, 2, 3) dtype=float32_ref>

for item in  tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
# >>> <tf.Variable 'aaa/aaa1:0' shape=(1, 2, 3) dtype=float32_ref>
# >>> <tf.Variable 'bbb/bbb1:0' shape=(4, 5, 6) dtype=float32_ref>

for item in  tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope="bbb"):
    print(item)
# >>> <tf.Variable 'bbb/bbb1:0' shape=(4, 5, 6) dtype=float32_ref>