tf.assign和赋值运算符之间的区别(=)

时间:2017-08-20 06:53:22

标签: tensorflow

我试图理解tf.assign和赋值运算符(=)之间的区别。我有三组代码

首先,使用简单的tf.assign

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  assign_op = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(assign_op)
    print a.eval()
    print a.eval()

输出预计为

2
2
2

其次,使用赋值运算符

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = a + 1
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   print sess.run(a)
   print a.eval()
   print a.eval()

结果仍然是2,2,2。

第三,我使用两者 tf.assign和赋值运算符

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(a)
    print a.eval()
    print a.eval()

现在,输出变为2,3,4。

我的问题是

  1. 在使用(=)的第二个片段中,当我有sess.run(a)时,似乎我正在运行一个分配操作。那么“a = a + 1”在内部创建一个赋值操作,如assign_op = tf.assign(a,a + 1)?会话运行的op真的只是assign_op吗?但是当我运行a.eval()时,它不会继续增加a,因此看起来eval正在评估一个“静态”变量。

  2. 我不确定如何解释第3个片段。为什么两个evals增加了一个,但第二个片段中的两个evals没有?

  3. 感谢。

3 个答案:

答案 0 :(得分:4)

这里的主要困惑是,执行a = a + 1会将Python变量a重新分配给加法运算a + 1的结果张量。另一方面,tf.assign是用于设置TensorFlow变量值的操作。

a = tf.Variable(1, name="a")
a = a + 1

这相当于:

a = tf.add(tf.Variable(1, name="a"), 1)

考虑到这一点:

  

在使用(=)的第二个片段中,当我有sess.run(a)时,我似乎正在运行一个分配操作。那么" a = a + 1"在内部创建一个分配操作,如assign_op = tf.assign(a,a + 1)? [...]

看起来可能如此,但事实并非如此。如上所述,这只会重新分配Python变量。如果没有tf.assign或任何其他更改变量的操作,它将保持值1.每次评估a时,程序将始终计算a + 1 => 1 + 1

  

我不确定如何解释第3个片段。为什么这两个版本增加了一个,但第二个片段中的两个版本没有增加?

这是因为在第三个代码段中调用赋值张量上的eval()也会触发变量赋值(请注意,这与使用当前会话执行session.run(a)非常不同)。

答案 1 :(得分:0)

对于代码段1

with tf.Graph().as_default():
  a = tf.Variable(1, name="a_var")
  assign_op = tf.assign(a, tf.add(a,1,name='ADD'))

  b = tf.Variable(112)
  b = b.assign(a)  
  print(a)
  print(b)
  print(assign_op)  

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())    
    print (sess.run(a))      
    print ("assign_op : ",sess.run(assign_op))    
    print("         b :- ",b.eval())
    print (sess.run(a))
    print (sess.run(a))    
    print ("assign_op : ",sess.run(assign_op))
    print (sess.run(a))  
    print (sess.run(a))     
    writer = tf.summary.FileWriter("/tmp/log", sess.graph)
    writer.close()

此代码段1的o / p:

<tf.Variable 'a_var:0' shape=() dtype=int32_ref>
Tensor("Assign_1:0", shape=(), dtype=int32_ref)
Tensor("Assign:0", shape=(), dtype=int32_ref)
1
assign_op :  2
        b :-  2
2
2
assign_op :  3
3
3

have a look at tensorboard's computational graph

要点:

  1. 第一个变量&#39; a&#39;被评估,所以你得到o / p:1
  2. next sess.run(assign_op),execution =&gt; assign_op = tf.assign(a,tf.add(a,1,name =&#39; ADD&#39;)),它可以更新变量&#39;(= 2)并创建& #39; ASSIGN_OP&#39;这是张量型物体。
  3. 对于代码段2 see computational graph, you'll get the idea (请注意,没有用于分配操作的节点)

    with tf.Graph().as_default():
    a = tf.Variable(1, name="Var_a")
    just_a = a + 1  
    print(a)
    print(just_a)
    
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      print (sess.run(a))  
      print (sess.run(a))   
      print ("just_a : ",sess.run(just_a))    
      print (sess.run(a))
      print (sess.run(a))
      print ("just_a : ",sess.run(just_a)) 
      print (sess.run(a))  
      print (sess.run(a)) 
      writer = tf.summary.FileWriter("/tmp/log", sess.graph)
      writer.close()
    

    代码段2的o / p:

    enter code here
    <tf.Variable 'Var_a:0' shape=() dtype=int32_ref>
    Tensor("add:0", shape=(), dtype=int32)
    1
    1
    just_a :  2
    1
    1
    just_a :  2
    1
    1
    

    对于代码段3 Computational graph

    with tf.Graph().as_default():
    a = tf.Variable(1, name="Var_name_a")
    a = tf.assign(a, tf.add(a,5))
    
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())    
      print (sess.run(a))  
      print (sess.run(a))
      print ("        a : ",sess.run(a))
      print (sess.run(a))
      print (sess.run(a))
      print ("         a : ",sess.run(a))
      print (sess.run(a))  
      print (sess.run(a))        
      writer = tf.summary.FileWriter("/tmp/log", sess.graph)
      writer.close()
    

    o / p for snippet 3:

    enter code here
    6
    11
        a :  16
    21
    26
         a :  31
    36
    41
    

    现在,如果你看看这个片段的计算图,它看起来与片段1的计算图相似/完全相同。但这里的问题是代码a = tf.assign(a,tf.add(a,5)),而不是只更新变量&#39; a&#39;但也创造了另一个张量&#39; a&#39;再次。

    现在刚刚创建了&#39; a&#39;将由

    使用
    print (sess.run(a))
    

    这个&#39; a&#39;将是a = tf.assign(a,tf.add(a,5))

    &#39; a&#39;来自tf.add(a,5)只不过是&#39;(= 1)=&gt; a = tf.Variable(1,name =&#34; Var_name_a&#34;)...所以5 + 1 = 6被分配给原始&#39; a&#39;而这个原创的&#39; a&#39;被分配给新的&#39; a。

    我还有一个例子可以同时解释这个概念

    <强> check the graph here

    enter code here
    with tf.Graph().as_default():
      w = tf.Variable(10,name="VAR_W") #initial val = 2
    
      init_op = tf.global_variables_initializer()
    
     # Launch the graph in a session.
     with tf.Session() as sess:
        # Run the variable initializer.
        sess.run(init_op)
        print(w.eval())
        print(w) #type of 'w' before assign operation
    
        #CASE:1
        w = w.assign(w + 50)#adding 100 to var w
        print(w.eval())      
        print(w) #type of 'w' after assign operation
    
        # now if u try  =>  w = w.assign(w + 50), u will get error bcoz newly 
        created 'w' is considered here which don't have assign attribute
    
        #CASE:2    
        w = tf.assign(w, w + 100) #adding 100 to var w
        print(w.eval())  
        #CASE:3    
        w = tf.assign(w, w + 300) #adding 100 to var w
        print(w.eval())    
        writer = tf.summary.FileWriter("/tmp/log", sess.graph)
        writer.close()
    

    以上代码段的o / p:

    10
    <tf.Variable 'VAR_W:0' shape=() dtype=int32_ref>
    60
    Tensor("Assign:0", shape=(), dtype=int32_ref)
    210
    660
    

答案 2 :(得分:0)

首先,anwser并不十分精确。 IMO,在python对象和tf对象之间没有区别。它们都是由python GC管理的内存对象。

如果将第二个a更改为b,并打印出var,

In [2]: g = tf.Graph()

In [3]: with g.as_default():
   ...:     a = tf.Variable(1, name='a')
   ...:     b = a + 1
   ...:

In [4]: print(a)
<tf.Variable 'a:0' shape=() dtype=int32_ref>

In [5]: print(b)
Tensor("add:0", shape=(), dtype=int32)

In [6]: id(a)
Out[6]: 140253111576208

In [7]: id(b)
Out[7]: 140252306449616

ab不在内存中引用同一对象。

绘制计算图或内存图

第一行

# a = tf.Varaible(...
a -> var(a)

第二行

# b = a + 1
b -> add - var(a)
      |
       \-- 1

现在,如果将其替换回b = a + 1a = a + 1,则分配操作后的a指向tf.add对象,而不是变量{{1} }增加1。

运行a时,您将通过该sess.run运算符获取结果,而对原始add变量没有副作用。

另一方面,

a将具有在会话下更新图状态的副作用。