我已经使用这些代码成功恢复了一部分预训练模型:
network = importlib.import_module(args.model_def)
exclude = ['InceptionResnetV1/Bottleneck/BatchNorm/beta']
include = [v.name for v in tf.trainable_variables()]
variables_to_restore = slim.get_variables_to_restore(include=include, exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, pretrained_model_dir)
那我只想冻结这些恢复的层,该怎么办?
因为我有了这些张量的列表,所以我认为有两种方法(但不知道如何编码):
(1)set those tensor in the 'variables_to_restore' list to constant
(2)set those tensor in the 'variables_to_restore' list to 'untrainable'