我正在尝试使用vgg_16微调 Momentum Optimizer 模型。为此,我使用了来自here的预训练模型。
在微调之前,我从模型中分配如下的变量值,
variables_to_restore = slim.get_variables_to_restore(exclude=["vgg_16/fc8"])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(model_path, variables_to_restore)
注意,我不排除vgg_16/*/*/Momentum
个变量。因此我收到了一个错误,
ValueError: Checkpoint is missing variable [vgg_16/conv1/conv1_1/weights/Momentum],
正如所料。
我的问题是,在exlude列表中包含所有Momentum变量非常麻烦(example)。有没有更聪明的方法来排除Momentum变量?
这很重要,因为对于像resnet这样的大型模型来说,手动输入排除项是不可能的。
提前谢谢!
答案 0 :(得分:1)
您可以使用以下代码解决此问题:
def _init_fn():
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
else:
checkpoint_path = FLAGS.checkpoint_path
tf.logging.info('Fine-tuning from %s' % checkpoint_path)
return slim.assign_from_checkpoint_fn(
checkpoint_path,
variables_to_restore,
ignore_missing_vars=FLAGS.ignore_missing_vars)
在slim.learning.train(init_fn=init_fn,)