Sharing subnetwork weights across different scopes in TensorFlow

Time: 2017-10-05 12:07:35

Tags: python tensorflow

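Using TensorFlow, I am trying to share the same weights from the same network across different variable scopes in order to save memory. However, there does not seem to be an easy way to do this. I have prepared a small code example that illustrates, on a small scale, what I want to do with larger subnetworks: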

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("super_scope_one"):
        scope1 = tf.variable_scope("sub_scope_one")
        with scope1:
            number_one = tf.get_variable("number_one", shape=[1],
                                         initializer=tf.ones_initializer)
    with tf.variable_scope("super_scope_two"):
        with tf.variable_scope("sub_scope_one", reuse=True) as scope2:
            # Here is the problem.
            # scope1.reuse_variables() # this crashes too if reuse=None.
            number_one = tf.get_variable("number_one", shape=[1])
        with tf.variable_scope("sub_scope_two"):
            number_two = tf.get_variable("number_two", shape=[1],
                                         initializer=tf.ones_initializer)
        number_three = number_one + number_two

    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph):
    init_op.run()
    print(number_three.eval())

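Is there a way to share the variable between the two sub-scopes without removing the enclosing scopes? If not, is there a good reason why this would be a bad idea?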

1 answer:

Answer 0 (score: 1)

You only need to define number_one once, in "super_scope_one", and can then use it in "super_scope_two" as well.

Two variables from different scopes can be used together. See below:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("super_scope_one"):
        scope1 = tf.variable_scope("sub_scope_one")
        with scope1:
            number_one = tf.get_variable("number_one", shape=[1],
                                         initializer=tf.ones_initializer)
    with tf.variable_scope("super_scope_two"):
        with tf.variable_scope("sub_scope_two"):
            number_two = tf.get_variable("number_two", shape=[1],
                                         initializer=tf.ones_initializer)
        number_three = number_one + number_two

    init_op = tf.global_variables_initializer()

    with tf.Session(graph=graph):
        init_op.run()
        print(number_three.eval())

This prints [2.]
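
If the shared variable has to be retrieved through tf.get_variable (for example because the subnetwork is built by a function that is called twice), another option may be to capture the scope object and re-enter it later. The sketch below is not part of the original answer. It assumes a TensorFlow version that ships the tensorflow.compat.v1 module, and the name shared_one is introduced here purely for illustration. Passing a captured scope object instead of a string to tf.variable_scope re-enters the original scope, so the lookup resolves to "super_scope_one/sub_scope_one/number_one" even from inside "super_scope_two":

```python
# Assumption: TensorFlow with the compat.v1 module is installed.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("super_scope_one"):
        # "as scope1" captures the VariableScope object so it can be re-entered.
        with tf.variable_scope("sub_scope_one") as scope1:
            number_one = tf.get_variable("number_one", shape=[1],
                                         initializer=tf.ones_initializer())
    with tf.variable_scope("super_scope_two"):
        # Re-enter the captured scope by object, not by name. The enclosing
        # "super_scope_two" prefix is not applied to this lookup.
        with tf.variable_scope(scope1, reuse=True):
            shared_one = tf.get_variable("number_one", shape=[1])  # hypothetical name
        with tf.variable_scope("sub_scope_two"):
            number_two = tf.get_variable("number_two", shape=[1],
                                         initializer=tf.ones_initializer())
        number_three = shared_one + number_two

    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    result = sess.run(number_three)

print(shared_one.name)  # full name of the variable that was reused
print(result)
```

By contrast, opening the scope by string, as in tf.variable_scope("sub_scope_one", reuse=True) inside "super_scope_two", asks for "super_scope_two/sub_scope_one/number_one", which was never created. That is why the code in the question raises an error.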