How do I restore the weights of a trained model after replacing the input of a saved graph (a placeholder) with a Dataset iterator?

Date: 2019-04-11 08:12:28

Tags: python-3.x tensorflow tensorflow-datasets

I am trying to restore a saved model that was built with a placeholder input, replace the placeholder with a tf.Dataset, and retrain the model.

I followed the instructions here: How to replace the input of a saved graph, e.g. a placeholder by a Dataset iterator?

However, I get an error saying that the variables from the imported graph def are not initialized.

So the question is: after importing the graph def with the new input map, how do I restore the variables from the original graph into the new graph?

import tensorflow as tf

# Create simple graph:
x = tf.placeholder(dtype=tf.int64, shape=[1], name='x')
v1 = tf.get_variable("v1", shape=[1], initializer=tf.zeros_initializer, dtype=tf.int64)
add = v1 + x
inc_v1 = v1.assign(v1 + 1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    sess.run(inc_v1)
    v1_res = sess.run(v1)
    res = sess.run(add, feed_dict={x: [4]})
    print("res:", res)
    print("v1: ", v1_res)
    saver.save(sess, "/tmp/switch.ckpt")

graph_def = tf.get_default_graph().as_graph_def()
tf.reset_default_graph()


batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['add:0'])


with tf.Session() as sess:
    print(sess.run(y))

I expected the output of y to be 1, but instead I get an error that the variables in the graph are uninitialized:

FailedPreconditionError: Attempting to use uninitialized value import/v1
 [[node import/v1/read (defined at <ipython-input-20-5de7b2bcb219>:26) ]]

1 Answer:

Answer 0 (score: 0)

Actually, you just made a mistake in the tensor names you passed to input_map and return_elements. There are two tricks:

  1. Name every operation, and use tf.name_scope so that the names are easy to find.
  2. You can list all operation names like this (see the sketch after this list): print([n.name for n in tf.get_default_graph().as_graph_def().node])
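
As an illustration of trick 2, here is a minimal sketch applied to the graph built in the question; the example names in the comments are simply what that code would produce:

import tensorflow as tf

# Dump every node name in the current graph so the exact strings for
# input_map / return_elements can be copied verbatim.
graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)  # e.g. x, v1, add, ... (or graph/x, graph/AAdd, ... when a name_scope is used)

# Note: tf.import_graph_def prefixes the imported copies with "import/" by
# default (configurable via its name= argument), which is why the error above
# mentions import/v1.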

Here is the corrected code:

# Create simple graph:
with tf.name_scope('graph'):
    x = tf.placeholder(dtype=tf.int64, shape=[1], name='x')
    v1 = tf.get_variable("v1", shape=[1], initializer=tf.zeros_initializer, dtype=tf.int64)
    y = tf.add(x, v1, name='AAdd')  # just to make sure that you can find the name
    inc_v1 = v1.assign(v1+1)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

...

graph_def = tf.get_default_graph().as_graph_def()  # graph_def holds the graph structure only, not the variable values

print('\n### check nodes')
for n in tf.get_default_graph().as_graph_def().node:
    print(n.name)
tf.reset_default_graph()

batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()

# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'graph/x:0': batch}, return_elements=['graph/add/y:0'])

# to load your ./tmp/switch.ckpt you should use tf.train.import_meta_graph()
# like in the second block of the answer https://stackoverflow.com/questions/50364377/how-to-replace-the-input-of-a-saved-graph-e-g-a-placeholder-by-a-dataset-itera

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(y))
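
To actually restore the trained value of v1 after swapping the placeholder for the Dataset iterator (the original question), the tf.train.import_meta_graph() route mentioned in the comment above can look like the following. This is only a minimal sketch, assuming the checkpoint and its .meta file written by the first snippet are still at /tmp/switch.ckpt:

import tensorflow as tf

tf.reset_default_graph()

batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()

# Rebuild the saved graph (variables and Saver included), remapping the
# placeholder 'x:0' to the Dataset iterator while importing.
saver = tf.train.import_meta_graph('/tmp/switch.ckpt.meta', input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('add:0')

with tf.Session() as sess:
    # Restore the checkpoint instead of running an initializer,
    # so v1 keeps its trained value (1).
    saver.restore(sess, '/tmp/switch.ckpt')
    for i in range(10):
        print(sess.run(y))  # v1 + the current Dataset element: [1], [2], ..., [10]

Unlike tf.import_graph_def(), import_meta_graph() recreates the variable and Saver objects from the saved MetaGraph, so saver.restore() can load the trained weights directly into the new graph.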