我是TensorFlow的新手,所以即使这个问题完全是胡说八道,也请跟我一起玩......
我有一个代码
1)定义网络
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
y_ = tf.placeholder(tf.float32, shape=[None, 10], name='reference')
...
fc_b2_hist = tf.summary.histogram('b_fc2', b_fc2)
2)然后用
恢复模型with tf.Session() as sess:
#NOTE
#sess.run(tf.initialize_all_variables())
model_path = tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph(model_path+'.meta')
saver.restore(sess, model_path)
all_vars = tf.trainable_variables()
for v in all_vars:
print(sess.run(v))
这个恢复模型的代码在单独的文件中运行时效果很好。 但是,在此运行时,它会中止并显示以下错误消息
回溯(最近一次呼叫最后一次):文件" lenet_my.py",第160行, print(sess.run(v))File" /usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", 第766行,在运行中 run_metadata_ptr)File" /usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", 第964行,在_run feed_dict_string,options,run_metadata)File" /usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", 第1014行,在_do_run中 target_list,options,run_metadata)File" /usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", 第1034行,在_do_call中 raise type(e)(node_def,op,message)tensorflow.python.framework.errors_impl.FailedPreconditionError: 试图使用未初始化的值lenet_model / conv_pool_1 / W_conv1 [[Node:_send_lenet_model / conv_pool_1 / W_conv1_0 = _SendT = DT_FLOAT,client_terminated = true,recv_device =" / job:localhost / replica:0 / task:0 / cpu:0", send_device =" /作业:本地主机/复制:0 /任务:0 / CPU:0&#34 ;, send_device_incarnation = 422131278131772803, tensor_name =" lenet_model / conv_pool_1 / W_conv1:0&#34 ;, _device =" /作业:本地主机/复制:0 /任务:0 / CPU:0"]]
在我第一次看到这条消息之后,我在#NOTE下取消了这条线,这是
sess.run(tf.initialize_all_variables())
它没有显示这样的错误,但是预训练的变量没有恢复,并且是在定义网络时如何定义的。
所以我有两个问题!
首先,我不知道在单独的文件中运行代码和在一个文件中运行它以获得此类HORRIFYING错误消息之间的区别 其次,我不明白为什么初始化变量然后使用上面编写的代码恢复模型不会恢复以前训练过的变量。
Thnx提前
答案 0 :(得分:0)
我认为 不能 运行tf.train.import_meta_graph()
可能会有所帮助。
从.meta
文件的导入将创建该文件中指定的新图形,您只需构建自己的图形就不需要此新图形。
只要说:
saver = tf.train.Saver()
with tf.Session() as sess:
model_path = tf.train.latest_checkpoint(model_path)
saver.restore(sess, model_path)
答案 1 :(得分:-1)
也许你应该把
saver=tf.train.import_meta_graph(model_path+'.meta')
超出你的“会话”
以下是我的代码:
saver = tf.train.import_meta_graph('./models/xxx.ckpt-30000.meta')
with tf.Session() as sess:
saver.restore(sess,'./models/xxx.ckpt-30000')
希望它有用