TensorFlow complains about a missing feed_dict while restoring the graph

Time: 2017-03-02 19:12:52

Tags: tensorflow feed restore

I have built a CNN for image classification. During training I saved several checkpoints. The data is fed into the network through a feed_dict.

Now I want to restore the model, which fails, and I cannot figure out why. The relevant lines of code are:

with tf.Graph().as_default():

....

if checkpoint_dir is not None:
    # create a saver plus a hook that periodically writes checkpoints
    checkpoint_saver = tf.train.Saver()
    session_hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir,
                                                      save_secs=flags.save_interval_secs,
                                                      saver=checkpoint_saver))
....

with tf.train.MonitoredTrainingSession(
        save_summaries_steps=flags.save_summaries_steps,
        hooks=session_hooks,
        config=tf.ConfigProto(
            log_device_placement=flags.log_device_placement)) as mon_sess:

    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:

        # restoring from the checkpoint file
        checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path)

        global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print("Model restored from checkpoint: global_step = %s" % global_step_restore)

The line "checkpoint_saver.restore" throws the error:

Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1022, in _do_call
    return fn(*args)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1004, in _run_fn
    status, run_metadata)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_images' with dtype float
  [[Node: input_images = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Does anyone know how to solve this? Why do I need a populated feed_dict just to restore the graph?

Thanks in advance!

UPDATE

Here is the code of the Saver object's restore method:

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters.
      save_path: Path where parameters were previously saved.
    """
    if self._is_empty:
      return
    sess.run(self.saver_def.restore_op_name,
             {self.saver_def.filename_tensor_name: save_path})

What I don't get: why is the graph executed immediately? Am I using the wrong method? I just want to restore all trainable variables.

1 answer:

Answer 0 (score: 1)

The problem was caused by the SessionRunHook that handles the logging:

The original hook:

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1

  def before_run(self, run_context):
    self._step += 1
    self._start_time = time.time()
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    duration = time.time() - self._start_time
    loss_value = run_values.results
    if self._step % 5 == 0:
      num_examples_per_step = FLAGS.batch_size
      examples_per_sec = num_examples_per_step / duration
      sec_per_batch = float(duration)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))
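
Why this breaks the restore: before_run returns tf.train.SessionRunArgs(loss), which adds the loss tensor to the fetches of every run() that goes through the monitored session, including the run() that Saver.restore(..) issues internally. A minimal, self-contained sketch of the effect (hypothetical graph, assuming TF 1.x):

import tensorflow as tf

# Hypothetical graph: a placeholder-dependent loss, as in the question.
x = tf.placeholder(tf.float32, name="input_images")
w = tf.Variable(1.0)
loss = x * w

class _AlwaysFetchLoss(tf.train.SessionRunHook):
    def before_run(self, run_context):
        # Silently adds `loss` to the fetches of EVERY run() call.
        return tf.train.SessionRunArgs(loss)

with tf.train.MonitoredTrainingSession(hooks=[_AlwaysFetchLoss()]) as sess:
    # Even a run that only touches `w` now also evaluates `loss`, so
    # TensorFlow demands a value for the placeholder:
    sess.run(w)  # InvalidArgumentError: You must feed a value for 'input_images'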

The modified hook:

class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def __init__(self, flags, loss_op):
        self._flags = flags
        self._loss_op = loss_op
        self._start_time = time.time()

    def begin(self):
        self._step = 0

    def before_run(self, run_context):
        if self._step == 0:
            run_args = None
        else:
            run_args = tf.train.SessionRunArgs(self._loss_op)

        return run_args

    def after_run(self, run_context, run_values):

        if self._step > 0:
            duration_n_steps = time.time() - self._start_time
            loss_value = run_values.results
            if self._step % self._flags.log_every_n_steps == 0:
                num_examples_per_step = self._flags.batch_size

                duration = duration_n_steps / self._flags.log_every_n_steps
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

                self._start_time = time.time()
        self._step += 1
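
The reworked hook is then registered like the checkpoint hook in the question, e.g. (assuming the graph's loss op is named loss_op):

session_hooks.append(_LoggerHook(flags, loss_op))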

Explanation:

The first iteration is no longer logged. As a result, the session.run that Saver.restore(..) executes no longer requests the loss and therefore no longer needs a populated feed dictionary.
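
A related option, for completeness: in TF 1.x, MonitoredTrainingSession can perform the restore itself when it is given a checkpoint_dir, so no manual Saver.restore(..) call has to pass through the hook-wrapped session at all. A minimal sketch, reusing the names from the question:

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_dir,  # restores the latest checkpoint automatically
        save_summaries_steps=flags.save_summaries_steps,
        hooks=session_hooks,
        config=tf.ConfigProto(
            log_device_placement=flags.log_device_placement)) as mon_sess:
    ...  # the variables already hold the restored values here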