Tensorflow saver.restore()不恢复网络

时间:2017-06-05 21:57:16

标签: python-2.7 tensorflow deep-learning

我完全迷失了张量流保护程序。

我试图遵循基本张量流深度神经网络模型教程。我想弄清楚如何训练网络几次迭代,然后在另一个会话中加载模型。

ClassNotFoundException

跳过训练。

with tf.Session() as sess:
    graph = tf.Graph()
    x = tf.placeholder(tf.float32,shape=[None,784])
    y_ = tf.placeholder(tf.float32, shape=[None,10])

    sess.run(global_variables_initializer())

    #Define the Network
    #(This part is all copied from the tutorial - not copied for brevity)
    #See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/

控制台打印出来:

  
    

步骤0,训练精度0.16

         

测试精度0.0719

         

步骤100,训练准确度0.88

         

测试准确度0.8734

  

接下来我想加载模型

    #Train the Network
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
                     cross_entropy,global_step=global_step)
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    saver = tf.train.Saver()

    for i in range(101):
        batch = mnist.train.next_batch(50)
        if i%100 == 0:
        train_accuracy = accuracy.eval(feed_dict=
                           {x:batch[0],y_:batch[1]})
        print 'Step %d, training accuracy %g'%(i,train_accuracy)
            train_step.run(feed_dict={x:batch[0], y_: batch[1]})
        if i%100 == 0:
            print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
                       mnist.test.images, y_: mnist.test.labels})

        saver.save(sess,'./mnist_model')

现在我想重新测试以查看模型是否已加载

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('mnist_model.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    sess.run(tf.global_variables_initializer())

控制台打印出来:

  
    

测试精度0.1151

  

模型是否正在保存任何数据?我做错了什么?

1 个答案:

答案 0 :(得分:4)

保存模型时,通常所有全局变量都保存在外部文件中,而局部变量则不保存。您可以查看此answer以了解其中的差异。

恢复代码中的错误是在 tf.global_variable_initializer()之后调用saver.restore() saver.restore文档提到了

  

要恢复的变量不必初始化,因为恢复本身就是一种初始化变量的方法。

因此,请尝试删除该行,

sess.run(tf.global_variables_initializer())

理想情况下,您应该用

替换它
sess.run(tf.local_variables_initializer())