Tensorflow:改进从检查点加载的错误预测的Tensorflow模型

时间:2017-11-06 16:26:59

标签: python tensorflow

在训练网络并输出模型后,是否可以在从检查点加载第一个模型后将单个错误预测提供给网络,以帮助提高性能。

目前我尝试过的解决方案涉及:

    with tf.Session() as sess:


        #Try new method...
        saver = tf.train.import_meta_graph('./model2.ckpt.meta')
        saver.restore(sess, './model2.ckpt')
        graph = tf.get_default_graph()

        x_1 = graph.get_tensor_by_name("x:0")
        y_1 = graph.get_tensor_by_name("y_:0")
        train_step = graph.get_tensor_by_name("train_step:0")

        trX = image
        trY = np.eye(10)[digit]

        sess.run(train_step, feed_dict={x: trX, y_: [trY],
                            keep_prob:0.5})

        new_saver = tf.train.Saver()
        new_saver.save(sess, "./model2.ckpt")
        sess.close()

        trX = image
        trY = np.eye(10)[digit]

        sess.run(train_step, feed_dict={x: trX, y_: [trY],
                            keep_prob:0.5})

或者这个解决方案:

    with tf.Session() as sess:

        saver = tf.train.Saver()
        saver.restore(sess, "./model2.ckpt")

        trX = image
        trY = np.eye(10)[digit]

        sess.run(train_step, feed_dict={x: trX, y_: [trY],
                            keep_prob:0.5})

两者都导致网络失去了正确预测先前图像的能力,这是正确预测的。

1 个答案:

答案 0 :(得分:0)

事实证明,问题在于训练步骤,而不是重量的恢复。在加载权重并运行训练步骤之后,权重对于每个值显示NAN。通过剪切交叉熵函数中的值来解决此问题。

原始的交叉熵函数是: cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

哪个被改成: cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y_conv,1e-10,1.0)))