我正在TF中建立一个CNN模型。
我保存了很少的变量
wc1 = tf.Variable(tf.random_normal([5, 5, 1, 32]), name='wc1')
wc2 = tf.Variable(tf.random_normal([5, 5, 32, 64]), name='wc2')
通过
saver = tf.train.Saver([wc1, wc2])
saver.save(sess, './cnn_model')
当我在另一个会话中恢复已保存的模型并使用tf.Print()打印张量时,无法打印它。波纹管代码用于恢复模型
sess = tf.Session()
saver = tf.train.import_meta_graph("./cnn_model.meta")
saver.restore(sess, './cnn_model')
wc1 = tf.get_default_graph().get_tensor_by_name("wc1:0")
wc2 = tf.get_default_graph().get_tensor_by_name("wc2:0")
while some_step:
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
wc1 = tf.Print(wc1, [wc1], 'WC1 is: ')
如何为我保存的模型打印/获取张量值?
答案 0 :(得分:3)
您只需执行sess.run(wc1)
即可获取模型的值。请参阅下面的代码示例:
>>> import tensorflow as tf
>>> wc1 = tf.Variable(tf.random_normal([5, 5, 1, 32]), name='wc1')
>>> saver = tf.train.Saver([wc1])
>>> with tf.Session('') as sess:
... tf.global_variables_initializer().run(session=sess)
... saver.save(sess, './cnn_model')
'./cnn_model'
>>> sess = tf.Session('')
>>> saver = tf.train.import_meta_graph("./cnn_model.meta")
>>> saver.restore(sess, './cnn_model')
INFO:tensorflow:Restoring parameters from ./cnn_model
>>> wc_r1 = tf.get_default_graph().get_tensor_by_name('wc1:0')
>>> sess.run(wc_r1)
array([[[[ 0.82639563, -0.33938187, 0.26812711, -0.32433796, 1.2584244 ,
-0.25379655, -0.16618967, 0.27060306, 1.53495347, 0.75791109,
-0.87073582, 1.48225808, 1.13401747, -1.80606318, 1.0940119 ,
0.52464408, -0.24058162, -1.36783814, -0.04032131, 0.82713342,
1.32288456, -1.32494891, 0.93615007, -0.74220407, 1.13950729,
0.39443189, 1.81868839, 0.91872966, 1.73204434, -1.26066136,
-1.12299716, -1.26222265]],
...