除了使用tf.global_variables_initializer()之外,如何初始化未保存的tensorflow变量

时间:2017-05-30 00:07:37

标签: python tensorflow save

我正在研究如何在tensorflow中保存/加载特定变量。

我可以毫无问题地加载和保存特定变量,但是,我无法弄清楚如何在不使用

的情况下初始化剩余的未保存变量
sess.run(tf.global_variables_initializer())   

然后用:

覆盖保存的变量
new_saver.restore(sess,'my_test_model2')

这可以正常工作并初始化未保存的变量(w2)并恢复已保存的变量(w1),但看起来非常笨拙且不讽刺。

我想知道如何摆脱

tf.global_variables_initializer()

,在我恢复w1变量的最后,到pythonic工作的东西。

我尝试了sess.run(tf.variables_initializer([w2]))并得到了输入:“^ w2 / Assign”不是此图的元素。)

我也试过sess.run(tf.variables_initializer(["w2:0"]))  并且得到了AttributeError:'str'对象没有属性'initializer'     导入tensorflow为tf

print(tf.__version__)
w1 = tf.Variable(tf.linspace(0.0, 0.5, 6), name="w1")
w2 = tf.Variable(tf.linspace(1.0, 5.0, 6), name="w2")
saver = tf.train.Saver({'w1':w1})
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for v in tf.global_variables():
      print (v.name)

print(sess.run(["w1:0"]))
print(sess.run(["w2:0"]))

saver.save(sess, 'my_test_model')

tf.reset_default_graph()

print ('-'*80 )

w1 = tf.Variable(tf.linspace(10.0, 50.0, 6), name="w1")
w2 = tf.Variable(tf.linspace(100.0, 500.0, 6), name="w2")
saver = tf.train.Saver({'w1':w1})
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for v in tf.global_variables():
      print (v.name)

print(sess.run(["w1:0"]))
print(sess.run(["w2:0"]))

saver.save(sess, 'my_test_model2')  

tf.reset_default_graph()

print ('-'*80 )
print("Let's load w1 \n")  

with tf.Session() as sess:
  # Loading the model structure from 'my_test_model.meta'
  new_saver = tf.train.import_meta_graph('my_test_model.meta')
  # I do this to make sure w1:0 and w2:0 are variables
  for v in tf.global_variables():
        print (v.name)  
  sess.run(tf.global_variables_initializer())  #<----- line I want to make more pythonic
#   sess.run(tf.variables_initializer([w2]))  # input: "^w2/Assign" is not an element of this graph.)
#   sess.run(tf.variables_initializer(["w2:0"])) #AttributeError: 'str' object has no attribute 'initializer'

# Loading the saved "w1" Variable
  new_saver.restore(sess,'my_test_model2')

  print(sess.run(["w1:0"]))
  print(sess.run(["w2:0"]))    

1 个答案:

答案 0 :(得分:1)

最后看了之后:

In TensorFlow is there any way to just initialize uninitialised variables?

我喜欢https://stackoverflow.com/users/1090562/salvador-dali回答并将其修改为使用itertools.compress,如果变量超过少数,则速度要快得多。

def initialize_uninitialized_vars(sess):
    from itertools import compress
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([~(tf.is_variable_initialized(var)) \
                                   for var in global_vars])
    not_initialized_vars = list(compress(global_vars, is_not_initialized))

    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))

我的代码变为:

with tf.Session() as sess:
  # Loading the model structure from 'my_test_model.meta'
  new_saver = tf.train.import_meta_graph('my_test_model.meta')  

  # Loading the saved "w1" Variable
  new_saver.restore(sess,'my_test_model2')

  # initialize the unitialized variables
  initialize_uninitialized_vars(sess)

  print(sess.run(["w1:0"]))
  print(sess.run(["w2:0"]))