我正在尝试恢复使用占位符输入构建的已保存模型,将占位符替换为tf.Dataset并重新训练模型。
我按照此处的说明进行操作: How to replace the input of a saved graph, e.g. a placeholder by a Dataset iterator?
但是我收到一个错误,即从导入的图形def中的变量未初始化。
所以问题是:用新的输入映射导入图def之后,如何将变量从原始图还原到新图中?
# Create simple graph:
x = tf.placeholder(dtype=tf.int64, shape=[1], name='x')
v1 = tf.get_variable("v1", shape=[1], initializer =
tf.zeros_initializer, dtype=tf.int64)
add = v1 + x
inc_v1 = v1.assign(v1+1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
sess.run(inc_v1)
v1_res = sess.run(v1)
res = sess.run(add, feed_dict={x: [4]})
print("res:", res)
print("v1: ", v1_res)
saver.save(sess, "/tmp/switch.ckpt")
graph_def = tf.get_default_graph().as_graph_def()
tf.reset_default_graph()
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch},
return_elements=['add:0'])
with tf.Session() as sess:
print(sess.run(y))
我期望y的输出为1,但出现错误 图中的变量未初始化:
FailedPreconditionError: Attempting to use uninitialized value
import/v1
[[node import/v1/read (defined at <ipython-input-20-5de7b2bcb219>:26) ]]
答案 0 :(得分:0)
实际上,您只是在input_map
和return_elements
中的张量名称上犯了一个错误。
有两个技巧:
tf.name_scope
来方便研究名称。print([n.name for n in tf.get_default_graph().as_graph_def().node])
以下是已更正的代码:
# Create simple graph:
with tf.name_scope('graph'):
x = tf.placeholder(dtype=tf.int64, shape=[1], name='x')
v1 = tf.get_variable("v1", shape=[1], initializer=tf.zeros_initializer, dtype=tf.int64)
y = tf.add(x, v1, name='AAdd') # just to make sure that you can find the name
inc_v1 = v1.assign(v1+1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
...
graph_def = tf.get_default_graph().as_graph_def() #graph_def is a language stubs
print('\n### check nodes')
for n in tf.get_default_graph().as_graph_def().node:
print(n.name)
tf.reset_default_graph()
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'graph/x:0': batch}, return_elements=['graph/add/y:0'])
# to load your ./tmp/switch.ckpt you should use tf.train.import_meta_graph()
# like in the second block of the answer https://stackoverflow.com/questions/50364377/how-to-replace-the-input-of-a-saved-graph-e-g-a-placeholder-by-a-dataset-itera
with tf.Session() as sess:
for i in range(10):
print(sess.run(y))