如何使tensorflow变量“本地”?

时间:2017-04-13 19:25:12

标签: python tensorflow

似乎tf.get_variable()tf.Variable()所产生的张量流变量是全局变量。发生在我身上的情况如下:假设我制作了以下两个文件:

  1. main.py

    from prac_var import defineVar
    for i in range(1000):
        defineVar()
    
  2. prac_var.py

    import tensorflow as tf
    def defineVar():
        with tf.variable_scope('weight'):    
            W = tf.Variable(tf.zeros([1,1]),name='W')
            print('\n',tf.trainable_variables())
    
  3. 现在,如果我运行main.py,它会生成

    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>]
    
    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>, <tf.Variable 'weight_1/W:0' shape=(1, 1) dtype=float32_ref>]
    
    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>, <tf.Variable 'weight_1/W:0' shape=(1, 1) dtype=float32_ref>, <tf.Variable 'weight_2/W:0' shape=(1, 1) dtype=float32_ref>]
    ...
    

    而我真正想要的是

    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>]
    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>]
    [<tf.Variable 'weight/W:0' shape=(1, 1) dtype=float32_ref>]
    ...
    

    我怎样才能以一种非常重要的方式解决这个问题?

2 个答案:

答案 0 :(得分:0)

首先,我想知道你是否理解Tensorflow首先构造一个计算图,然后用你定义的图进行所有计算。如果您知道变量的名称,则可以使用tf.get_variable()从任何地方到达所有变量...

如果你在两个不同的地方获得权重,两次都试图获得变量W你重复使用&#34;这些重量。这就是为什么引入可变范围的原因:  https://www.tensorflow.org/programmers_guide/variable_scope#variable_scope_example

如果你想要两个不同的权重,你可以说:

with tf.variable_scope('mnistweights'):
  Wmnist = tf.get_variable('W',...)
with tf.variable_scope('exampleweights'):
  Wtest = tf.get_variable('W',...)

现在您拥有名称为mnistweights / W和exampleweights / W的变量。

希望你能更好地理解它!

答案 1 :(得分:0)

我意外地找到了解决问题的方法。只需添加一行tf.reset_default_graph()

  1. main.py

    from prac_var import defineVar
    for i in range(1000):
        defineVar()
    
  2. prac_var.py

    import tensorflow as tf
    def defineVar():
        tf.reset_default_graph()
        with tf.variable_scope('weight'):    
            W = tf.Variable(tf.zeros([1,1]),name='W')
            print('\n',tf.trainable_variables())