在检查点中找不到TensorFlow恢复模型密钥变量

时间:2018-11-08 09:27:51

标签: python tensorflow

OSX py3.6 tf1.8

我是张量流初学者。我尝试为mnist训练模型。还原模型时出现错误。

from datetime import datetime

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
save_model_path = 'mnist_model/model.ckpt'


def train():
    learning_rate = 0.05
    batch_size = 100
    max_epochs = 100
    num_of_batch = int(mnist.train.num_examples / batch_size)
    now = datetime.utcnow().strftime("%Y%m%d%H%M%S")

    X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
    y = tf.placeholder(tf.float32, shape=[None, 10], name='y')
    print(X.name, y.name)

    W = tf.get_variable(shape=[784, 10], name='weight')
    b = tf.get_variable(initializer=tf.zeros([10]), name='bais')
    tf.summary.histogram("weights", W)
    tf.summary.histogram("biases", b)

    with tf.name_scope('pred'):
        y_pred = tf.nn.softmax(tf.matmul(X, W) + b, name='predict')
        print(y_pred.name)

    with tf.name_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
        tf.summary.scalar('loss', loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

    with tf.name_scope('acc'):
        correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='acc')
        print(accuracy.name)

    merged_summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)

        loss_avg = 0
        writer = tf.summary.FileWriter('mnist/{}'.format(now), sess.graph)
        for epoch in range(max_epochs):
            for i in range(num_of_batch):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                summary_str, _, l = sess.run([merged_summary_op, optimizer, loss], feed_dict={X: batch_x, y: batch_y})
                loss_avg += l
                global_step = epoch * num_of_batch + i
                writer.add_summary(summary_str, global_step)

                if global_step % 100 == 0:
                    print('Epoch {}: {} save model'.format(epoch, i))
                    # save model in halfway
                    saver.save(sess, save_model_path, global_step=global_step)

            loss_avg /= num_of_batch
            print('Epoch {}: Loss {}'.format(epoch, loss_avg))

        print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))
        saver.save(sess, save_model_path)


def predict(import_from_meta=False):
    if import_from_meta:
        meta_path = 'mnist_model/model.ckpt.meta'
        checkpoint_path = 'mnist_model'
    else:
        # stupid var WTF ValueError: No variables to save
        _ = tf.Variable(0)
        saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if import_from_meta:
            saver = tf.train.import_meta_graph(meta_path)
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
        else:
            saver.restore(sess, save_model_path)
        graph = tf.get_default_graph()
        X = graph.get_tensor_by_name('X:0')
        y = graph.get_tensor_by_name('y:0')
        accuracy = graph.get_tensor_by_name('acc/acc:0')
        print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))

        pred = graph.get_tensor_by_name('pred/predict:0')
        import matplotlib.pyplot as plt
        i = 90
        img_orign = mnist.train.images[i]
        img = img_orign.reshape((28, 28))
        plt.imshow(img, cmap='gray')
        plt.title(mnist.train.labels[i])
        plt.show()
        a = sess.run(pred, feed_dict={X: img_orign.reshape(-1, 784)})
        print(a.shape)
        import numpy as np
        print(np.argmax(a))


def check_ckpt():
    from tensorflow.python.tools import inspect_checkpoint as chkp
    chkp.print_tensors_in_checkpoint_file(save_model_path, tensor_name='', all_tensors=True)


if __name__ == '__main__':
    # train()
    predict(import_from_meta=False)
    # check_ckpt()

使用predict(import_from_meta=False)

错误:

WARNING:tensorflow:From /Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2018-11-08 16:53:40.482921: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Not found: Key Variable not found in checkpoint
Traceback (most recent call last):
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 115, in <module>
    predict(import_from_meta=False)
  File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 92, in predict
    saver.restore(sess, save_model_path)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1802, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 115, in <module>
    predict(import_from_meta=False)
  File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 84, in predict
    saver = tf.train.Saver()
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1338, in __init__
    self.build()
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1347, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1384, in _build
    build_save=build_save, build_restore=build_restore)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 835, in _build_internal
    restore_sequentially, reshape)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 472, in _AddRestoreOps
    restore_sequentially)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 886, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Key Variable not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

很奇怪,当我使用predict(import_from_meta=True)时,我会得到正确答案

然后,我尝试从check_ckpt()inspect variables in a checkpoint。我找不到合适的tensor_name。真好笑,name_scope就像个玩笑。

tensor_name:  bais
[-24.933702    1.7660792   3.697866  -14.221888    8.967291   42.149403
  -3.2693458  23.876926  -30.643892   -3.7861202]
tensor_name:  bais/Adam
[ 3.1726879e-07 -5.2043208e-07  3.4227469e-05  2.5119303e-07
 -2.0110610e-04  1.8493415e-04 -3.6275055e-06 -1.4343520e-04
 -7.2765622e-05  2.0172486e-04]
tensor_name:  bais/Adam_1
[5.2586905e-08 8.9204484e-08 1.5440051e-07 2.9412612e-07 2.4380788e-07
 3.4676964e-07 8.7062219e-08 1.8839150e-07 4.3878950e-07 4.2466107e-07]
tensor_name:  loss/beta1_power
0.0
tensor_name:  loss/beta2_power
1.2639432e-24
tensor_name:  weight
[[-0.03386476  0.03485525 -0.03267809 ... -0.08548199  0.00565728
  -0.01887459]
 [ 0.00370622  0.08523928  0.05811391 ... -0.07838921  0.05987743
   0.074329  ]
 [ 0.0180116   0.04400793 -0.0260816  ...  0.00807328  0.06537797
  -0.07446742]
 ...
 [-0.00665552 -0.03390152 -0.03889231 ... -0.01871967 -0.05968629
   0.07207178]
 [ 0.01317277  0.03459686 -0.03268962 ...  0.07082433  0.03290742
   0.03172391]
 [-0.04514085 -0.03013236  0.01006595 ...  0.01906221  0.02611361
   0.04348358]]
tensor_name:  weight/Adam
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
tensor_name:  weight/Adam_1
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

Process finished with exit code 0

那么我的代码有什么问题?为什么只想恢复模型,为什么必须在tf.train.Saver之前创建变量?

0 个答案:

没有答案