如何恢复保存的张量流模型?

时间:2017-05-07 09:45:10

标签: save restore

有两个python文件,第一个用于保存张量流 模型。第二个是恢复已保存的模型。

问题:

  1. 当我一个接一个地运行这两个文件时,没关系。

  2. 当我运行第一个时,重新启动编辑并运行第二个编辑 告诉我w1没有定义?

  3. 我想做的是:

    1. 保存张量流模型

    2. 恢复已保存的模型

    3. 它有什么问题?谢谢你的帮助?

      model_save.py

      import tensorflow as tf
      w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
      w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
      saver = tf.train.Saver()
      
      with tf.Session() as sess: 
      sess.run(tf.global_variables_initializer())
      saver.save(sess, 'SR\\my-model')
      

      model_restore.py

      import tensorflow as tf
      
      with tf.Session() as sess:    
      saver = tf.train.import_meta_graph('SR\\my-model.meta')
      saver.restore(sess,'SR\\my-model')
      print (sess.run(w1))
      

      enter image description here

1 个答案:

答案 0 :(得分:3)

简单来说,你应该使用

print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))

而不是 model_restore.py 文件中的print (sess.run(w1))

<强> model_save.py

import tensorflow as tf
w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(w1_node.eval()) # [ 0.43350926  1.02784836]
  #print(w1.eval()) # NameError: name 'w1' is not defined
  saver.save(sess, 'my-model')

w1_node仅在 model_save.py 中定义, model_restore.py 文件无法识别。 当我们通过Tensor调用name变量时,我们应该使用get_tensor_by_name,因为此帖Tensorflow: How to get a tensor by name?建议。

<强> model_restore.py

import tensorflow as tf

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('my-model.meta')
  saver.restore(sess,'my-model')
  print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))
  # [ 0.43350926  1.02784836]
  print(tf.global_variables()) # print tensor variables
  # [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>,
  #  <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>]
  for op in tf.get_default_graph().get_operations():
    print str(op.name) # print all the operation nodes' name