tensorflow如何保存和恢复模型?

时间:2016-12-30 13:44:37

标签: python model save tensorflow restore

首先,我知道如何使用tensorflow来保存和恢复模型,就像教程一样:

import tensorflow as tf
v1 = tf.Variable(1.32, name="v1")
v2 = tf.Variable(1.33, name="v2")

init = tf.initialize_all_variables()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print v2.eval(sess)
    save_path="model.ckpt"
    saver.save(sess,save_path)
    saver.restore(sess, save_path)
    print("Model restored.")

正如您所看到的,它运作良好。然后我在单个文件model_save.py中写了两个模型,就像这样:

# -*- coding: utf8 -*-

import tensorflow as tf 
v1 = tf.Variable(66, name="v1")
v2 = tf.Variable(77, name="v2")

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print 'v2', v2.eval(sess)
    save_path = "model_v/model.ckpt"
    saver.save(sess, save_path)
    print "ModelV saved."
    saver.restore(sess, save_path)
    print "ModelV restored."
    print 'v2', v2.eval(sess)

p1 = tf.Variable(88, name="p1")
p2 = tf.Variable(99, name="p2")

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print 'p2',p2.eval(sess)
    save_path = "model_p/model.ckpt"
    saver.save(sess, save_path)
    print "ModelP saved."
    saver.restore(sess, save_path)
    print "ModelP restored."
    print 'p2', p2.eval(sess)

效果也很好!然后我写了一个model.py文件,其中我将pv包装在两个模型中,这两个模型称为ModelPModelV,就像这样:

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

    def __init__(self):

        self.v1 = tf.Variable(66, name="v1")
        self.v2 = tf.Variable(77, name="v2")
        self.save_path = "model_v/model.ckpt"
        self.init = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
        self.sess = tf.Session()

    def train(self):
        self.sess.run(self.init)
        print 'v2', self.v2.eval(self.sess)

        self.saver.save(self.sess, self.save_path)
        print "ModelV saved."

    def predict(self):
        self.saver.restore(self.sess, self.save_path)
        print "ModelV restored."
        print 'v2', self.v2.eval(self.sess)

class ModelP():

    def __init__(self):

        self.p1 = tf.Variable(88, name="p1")
        self.p2 = tf.Variable(99, name="p2")
        self.save_path = "model_p/model.ckpt"
        self.init = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
        self.sess = tf.Session()

    def train(self):
        self.sess.run(self.init)
        print 'p2', self.p2.eval(self.sess)

        self.saver.save(self.sess, self.save_path)
        print "ModelP saved."

    def predict(self):
        self.saver.restore(self.sess, self.save_path)
        print "ModelP restored."
        print 'p2', self.p2.eval(self.sess)


if __name__ == '__main__':
    v = ModelV()
    p = ModelP()
    #v.train()
    v.predict()
    #p.train()
    p.predict()

然后我直接使用model_save.py保存的模型,它再次运行良好。 1 现在我将提出问题,我将ModelPModelV分成两个文件:

ModelV.py

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

    def __init__(self):

        self.v1 = tf.Variable(66, name="v1")
        self.v2 = tf.Variable(77, name="v2")
        self.save_path = "model_v/model.ckpt"
        self.init = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
        self.sess = tf.Session()

    def train(self):
        self.sess.run(self.init)
        print 'v2', self.v2.eval(self.sess)

        self.saver.save(self.sess, self.save_path)
        print "ModelV saved."

    def predict(self):
        self.saver.restore(self.sess, self.save_path)
        print "ModelV restored."
        print 'v2', self.v2.eval(self.sess)

if __name__ == '__main__':
    v = ModelV()
    v.train()
    v.predict()

ModelP.py

# -*- coding: utf-8 -*-

import tensorflow as tf

class ModelP():

    def __init__(self):

        self.p1 = tf.Variable(88, name="p1")
        self.p2 = tf.Variable(99, name="p2")
        self.save_path = "model_p/model.ckpt"
        self.init = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
        self.sess = tf.Session()

    def train(self):
        self.sess.run(self.init)
        print 'p2', self.p2.eval(self.sess)

        self.saver.save(self.sess, self.save_path)
        print "ModelP saved."

    def predict(self):
        self.saver.restore(self.sess, self.save_path)
        print "ModelP restored."
        print 'p2', self.p2.eval(self.sess)


if __name__ == '__main__':
    p = ModelP()
    p.train()
    p.predict()

然后我运行以下,首先是ModelV.py,然后是ModelP.py,然后是model.py; 2 看看发生了什么,错误发生了!以下是我的错误:

ModelV restored.
v2 77
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v2 not found in checkpoint
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v1 not found in checkpoint
Traceback (most recent call last):
File "model.py", line 58, in <module>
    p.predict()
File "model.py", line 47, in predict
    self.saver.restore(self.sess, self.save_path)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1388, in restore
    {self.saver_def.filename_tensor_name: save_path})
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 964, in _run
    feed_dict_string, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1014, in _do_run
    target_list, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1034, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key v2 not found in checkpoint
    [[Node: save_1/RestoreV2_3 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/RestoreV2_3/tensor_names, save_1/RestoreV2_3/shape_and_slices)]]

Caused by op u'save_1/RestoreV2_3', defined at:
File "model.py", line 54, in <module>
    p = ModelP()
File "model.py", line 36, in __init__
    self.saver = tf.train.Saver()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1000, in __init__
    self.build()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1030, in build
    restore_sequentially=self._restore_sequentially)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 624, in build
    restore_sequentially, reshape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 361, in _AddRestoreOps
    tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 200, in restore_op
    [spec.tensor.dtype])[0])
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_io_ops.py", line 441, in restore_v2
    dtypes=dtypes, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
    op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2240, in create_op
    original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1128, in __init__
    self._traceback = _extract_stack()

NotFoundError (see above for traceback): Key v2 not found in checkpoint
    [[Node: save_1/RestoreV2_3 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/RestoreV2_3/tensor_names, save_1/RestoreV2_3/shape_and_slices)]]

这只是我用来测试的一个例子,因为我有另一个项目,我在一个文件中写了两个模型,每个模型都是一个对象,每个都提供一个预测和训练界面,同样的错误发生了。同样的张量流只是在另一个模型的检查点文件中搜索第一个模型的变量。它混合起来,但它们有两种型号!谢谢你给我一些帮助!

0 个答案:

没有答案