I have built a CNN for image classification. During training I saved several checkpoints. Data is fed into the network via a feed_dict.
Now I want to restore the model, but the restore fails and I can't figure out why. The relevant lines of code are:
with tf.Graph().as_default():
    ....
    if checkpoint_dir is not None:
        checkpoint_saver = tf.train.Saver()
        session_hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir,
                                                          save_secs=flags.save_interval_secs,
                                                          saver=checkpoint_saver))
    ....
    with tf.train.MonitoredTrainingSession(
            save_summaries_steps=flags.save_summaries_steps,
            hooks=session_hooks,
            config=tf.ConfigProto(
                log_device_placement=flags.log_device_placement)) as mon_sess:
        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        if checkpoint and checkpoint.model_checkpoint_path:
            # restoring from the checkpoint file
            checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path)
            global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1]
            print("Model restored from checkpoint: global_step = %s" % global_step_restore)
The line `checkpoint_saver.restore` throws this error:
Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1022, in _do_call
    return fn(*args)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1004, in _run_fn
    status, run_metadata)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_images' with dtype float
     [[Node: input_images = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]]]
Does anyone know how to fix this? Why do I need a populated feed_dict just to restore the graph?
Thanks in advance!
Update
Here is the code of the Saver object's restore method:
def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters.
      save_path: Path where parameters were previously saved.
    """
    if self._is_empty:
        return
    sess.run(self.saver_def.restore_op_name,
             {self.saver_def.filename_tensor_name: save_path})
What I don't get: why is the graph executed immediately? Am I using the wrong method? I just want to restore all trainable variables.
Answer 0 (score: 1)
The problem was caused by a SessionRunHook that logs progress:
The original hook:

class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 5 == 0:
            num_examples_per_step = FLAGS.batch_size
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)
            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
The modified hook:
class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def __init__(self, flags, loss_op):
        self._flags = flags
        self._loss_op = loss_op
        self._start_time = time.time()

    def begin(self):
        self._step = 0

    def before_run(self, run_context):
        if self._step == 0:
            run_args = None
        else:
            run_args = tf.train.SessionRunArgs(self._loss_op)
        return run_args

    def after_run(self, run_context, run_values):
        if self._step > 0:
            duration_n_steps = time.time() - self._start_time
            loss_value = run_values.results
            if self._step % self._flags.log_every_n_steps == 0:
                num_examples_per_step = self._flags.batch_size
                duration = duration_n_steps / self._flags.log_every_n_steps
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))
                self._start_time = time.time()
        self._step += 1
Explanation:
The first iteration is no longer logged: on step 0 `before_run` returns `None` instead of requesting the loss op. So the `session.run` executed by `Saver.restore(..)` no longer needs a populated feed dict.
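To see why skipping the fetch on the first run fixes the restore, here is a small pure-Python sketch of the mechanism (no TensorFlow needed; `MockMonitoredSession`, `SessionRunArgs`, and `LoggerHook` are illustrative stand-ins, not the real API): a monitored session passes every `run` call, including the one issued by `Saver.restore(..)`, through each hook's `before_run`, and any extra fetch the hook requests must be fed like a normal fetch.

```python
class SessionRunArgs:
    """Stand-in for tf.train.SessionRunArgs (illustration only)."""
    def __init__(self, fetches):
        self.fetches = fetches

class MockMonitoredSession:
    """Toy model of MonitoredTrainingSession.run(): before every run it asks
    each hook for extra fetches and merges them into the call."""
    def __init__(self, hooks, placeholder_fetches):
        self._hooks = hooks
        # Fetches that depend on a placeholder and therefore need a feed_dict.
        self._placeholder_fetches = placeholder_fetches

    def run(self, fetches, feed_dict=None):
        extra = []
        for hook in self._hooks:
            args = hook.before_run(None)
            if args is not None:
                extra.append(args.fetches)
        # A hook-requested fetch with an unfed placeholder fails -- this is
        # the "You must feed a value for placeholder tensor" error above.
        for fetch in extra:
            if fetch in self._placeholder_fetches and not feed_dict:
                raise ValueError(
                    "You must feed a value for placeholder 'input_images'")
        for hook in self._hooks:
            hook.after_run(None, None)
        return fetches

class LoggerHook:
    """Like the modified hook: requests the loss op only after run #0,
    so a restore (which is run #0) needs no feed_dict."""
    def __init__(self):
        self._step = 0
    def before_run(self, run_context):
        return SessionRunArgs("loss") if self._step > 0 else None
    def after_run(self, run_context, run_values):
        self._step += 1

sess = MockMonitoredSession([LoggerHook()], placeholder_fetches={"loss"})
sess.run("restore_op")           # run #0: hook stays quiet, restore succeeds
sess.run("train_op", {"x": 1.0}) # run #1: hook asks for loss, which is fed
```

A hook that returns `SessionRunArgs(loss)` unconditionally (like the original one) would make run #0 raise, which is exactly why the restore failed.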