tensorflow:从检查点恢复以继续训练

时间:2016-09-02 13:55:14

标签: tensorflow restore checkpointing

在这种情况下,我想继续从checkpoint.i训练我的模型使用cifar-10示例并在cifar-10_train.py中做了一点改变,如下所示,它们几乎相同,除了我想要恢复检查点: 我用md替换了cifar-10。

"""

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import time
import numpy

import tensorflow.python.platform
from tensorflow.python.platform import gfile

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import md
"""



"""


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/root/test/INT/tbc',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 60000,        # 55000 steps per epoch
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '/root/test/INT/',
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")





def error_rate(predictions, labels):
  """Return the error rate based on dense predictions and 1-hot labels."""
  return 100.0 - (
      100.0 *
      numpy.sum(numpy.argmax(predictions, 0) == numpy.argmax(labels, 0)) /
      predictions.shape[0])






def train():
  """Train MD65500 for a number of steps."""
  with tf.Graph().as_default():
    # global_step = tf.Variable(0, trainable=False)

    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)






    # Get images and labels for CIFAR-10.
    images, labels = md.distorted_inputs()


    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = md.inference(images)

    # Calculate loss.
    loss = md.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = md.train(loss, global_step)

    # Predictions for the minibatch. there is no validation set or test set.
    # train_prediction = tf.nn.softmax(logits)
    train_prediction = logits

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    # sess = tf.Session(config=tf.ConfigProto(
        # log_device_placement=FLAGS.log_device_placement))
    # sess.run(init)

    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    # sess.run(init)


    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      # variables_to_restore = tf.get_collection(
          # slim.variables.VARIABLES_TO_RESTORE)


      variable_averages = tf.train.ExponentialMovingAverage(
          md.MOVING_AVERAGE_DECAY)
      variables_to_restore = {}
      for v in tf.all_variables():
        if v in tf.trainable_variables():
            restore_name = variable_averages.average_name(v)
        else:
            restore_name = v.op.name
        variables_to_restore[restore_name] = v


      ckpt = tf.train.get_checkpoint_state(FLAGS.pretrained_model_checkpoint_path)
      if ckpt and ckpt.model_checkpoint_path:

        # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, ckpt.model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), ckpt.model_checkpoint_path))
        # print("variables_to_restore")
        # print(variables_to_restore)

      else:
        sess.run(init)










    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph)       #####graph_def=sess.graph_def)

    # tf.add_to_collection('train_op', train_op)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value, predictions = sess.run([train_op, loss, train_prediction])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 100 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))
        # print('Minibatch error: %.5f%%' % error_rate(predictions, labels))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  # md.maybe_download()
  # if gfile.Exists(FLAGS.train_dir):
    # gfile.DeleteRecursively(FLAGS.train_dir)
  # gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()

当我运行代码时,错误如下:

[root@bogon md try]# pythonnew mdtbc_3.py
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcublas.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcudnn.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcufft.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcurand.so locally
Filling queue with 4000 CIFAR images before starting to train. This will take a few minutes.
I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:900] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties: 
name: GeForce GTX 980 Ti
major: 5 minor: 2 memoryClockRate (GHz) 1.228
pciBusID 0000:01:00.0
Total memory: 6.00GiB
Free memory: 5.78GiB
I tensorflow/core/common_runtime/gpu/gpu_init.cc:126] DMA: 0 
I tensorflow/core/common_runtime/gpu/gpu_init.cc:136] 0:     Y 
I tensorflow/core/common_runtime/gpu/gpu_device.cc:755] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 980 Ti, pci bus id: 0000:01:00.0)
2016-08-30 17:12:48.883303: Pre-trained model restored from /root/test/INT/model.ckpt-59999
WARNING:tensorflow:When passing a `Graph` object, please use the `graph` named argument instead of `graph_def`.
Traceback (most recent call last):
    File "mdtbc_3.py", line 195, in <module>
        tf.app.run()
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
        sys.exit(main(sys.argv))
    File "mdtbc_3.py", line 191, in main
        train()
    File "mdtbc_3.py", line 160, in train
        _, loss_value, predictions = sess.run([train_op, loss, train_prediction])
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 340, in run
        run_metadata_ptr)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 564, in _run
        feed_dict_string, options, run_metadata)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 637, in _do_run
        target_list, options, run_metadata)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 659, in _do_call
        e.code)
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value conv2/weights
     [[Node: conv2/weights/read = Identity[T=DT_FLOAT, _class=["loc:@conv2/weights"], _device="/job:localhost/replica:0/task:0/cpu:0"](conv2/weights)]]
Caused by op u'conv2/weights/read', defined at:
    File "mdtbc_3.py", line 195, in <module>
        tf.app.run()
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
        sys.exit(main(sys.argv))
    File "mdtbc_3.py", line 191, in main
        train()
    File "mdtbc_3.py", line 77, in train
        logits = md.inference(images)
    File "/root/test/md try/md.py", line 272, in inference
        stddev=0.1, wd=0.0)
    File "/root/test/md try/md.py", line 114, in _variable_with_weight_decay
        tf.truncated_normal_initializer(stddev=stddev))
    File "/root/test/md try/md.py", line 93, in _variable_on_cpu
        var = tf.get_variable(name, shape, initializer=initializer)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 339, in get_variable
        collections=collections)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 262, in get_variable
        collections=collections, caching_device=caching_device)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 158, in get_variable
        dtype=variable_dtype)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 209, in __init__
        dtype=dtype)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 318, in _init_from_args
        self._snapshot = array_ops.identity(self._variable, name="read")
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 609, in identity
        return _op_def_lib.apply_op("Identity", input=input, name=name)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
        op_def=op_def)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2154, in create_op
        original_op=self._default_original_op, op_def=op_def)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1154, in __init__
        self._traceback = _extract_stack()

当我取消注释第107行&#34; sess.run(init)&#34;它完美运行,但是初始化的模型,只是来自sctrach的新模型。我想从检查点恢复变量,并继续我的训练。我想恢复。

1 个答案:

答案 0 :(得分:3)

如果没有其他代码方便,我会说以下部分存在问题:

for v in tf.all_variables():
    if v in tf.trainable_variables():
        restore_name = variable_averages.average_name(v)
    else:
        restore_name = v.op.name
    variables_to_restore[restore_name] = v

因为您在此处指定要恢复的变量列表,但是您排除了一些变量(即可训练变量中的变量的v.op.name)。这将改变抛出错误的网络中的变量名称(再次,没有其余的代码,我真的不能说),s.t。一个(或多个)变量未正确恢复。两种方法(不是很复杂)可以帮助你:

  1. 如果不存储所有变量,请先进行初始化,然后恢复实际存储的变量。这确保你真正不关心的张量得到了初始化
  2. TF在存储网络方面非常有效。如果有疑问,请存储所有变量......