如何将变量复制到张量流中的另一个图形

时间:2017-07-24 05:52:04

标签: tensorflow

我希望将tensorflow varialbe从旧图复制到新图,然后删除旧图并将新图作为默认图。下面是我的代码,但它引发了一个AttributeError:'Graph'对象没有属性'variable1'。我是tensorflow的新手。任何人都可以给我一个具体的例子吗?谢谢。

import tensorflow as tf
import numpy as np

graph1 = tf.Graph()
with graph1.as_default():
    variable1 = tf.Variable(np.array([2.,-1.]), name='x1')
    initialize = tf.initialize_all_variables()

sess1=tf.Session(graph=graph1)
sess1.run(initialize)
print('sess1:',variable1.eval(session=sess1))

graph2 = tf.Graph()
with graph2.as_default():
    variable2=tf.contrib.copy_graph.copy_variable_to_graph(graph1.variable1,graph2)
sess2=tf.Session(graph=graph2)
#I want to remove graph1 and sess1, ande make graph2 and sess2 as default here.
print('sess2:',variable2.eval(session=sess2))

1 个答案:

答案 0 :(得分:1)

  1. tf.initialize_all_variables()已被删除。请改用tf.global_variables_initializer()

  2. 您不需要graph1.variable1。只需通过variable1

  3. 您忘记在第二个会话中初始化变量:

    initialize2 = tf.global_variables_initializer()
    sess2=tf.Session(graph=graph2)
    sess2.run(initialize2)
    
  4. 所以你的代码应该是这样的:

    import tensorflow as tf
    import numpy as np
    
    graph1 = tf.Graph()
    with graph1.as_default():
        variable1 = tf.Variable(np.array([2.,-1.]), name='x1')
        initialize = tf.global_variables_initializer()
        sess1=tf.Session(graph=graph1)
        sess1.run(initialize)
        print('sess1:',variable1.eval(session=sess1))
    
    graph2 = tf.Graph()
    with graph2.as_default():
        variable2=tf.contrib.copy_graph.copy_variable_to_graph(variable1,graph2)
        initialize2 = tf.global_variables_initializer()
        sess2=tf.Session(graph=graph2)
        sess2.run(initialize2)
        #I want to remove graph1 and sess1, ande make graph2 and sess2 as default here.
        print('sess2:',variable2.eval(session=sess2))