我正在使用TF-Slim从预训练模型中微调我的模型。当我使用create_train_op
时,我发现它的参数为variables_to_train
。在一些教程中,它使用了如下标志:
all_trainable = [v for v in tf.trainable_variables()]
trainable = [v for v in all_trainable]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
variables_to_train=trainable,
summarize_gradients=True)
但在官方TF-Slim中,它没有使用
all_trainable = [v for v in tf.trainable_variables()]
trainable = [v for v in all_trainable]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
summarize_gradients=True)
那么,使用和不使用variables_to_train
之间有什么不同?
答案 0 :(得分:2)
您的两个示例都执行相同的操作。您训练图中出现的所有可训练变量。使用参数variables_to_train
,您可以定义在训练期间应更新哪些变量。
一个用例是,当您已经预先训练了诸如词嵌入之类的不想在模型中训练的东西时。与
train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
variables_to_train=train_vars,
summarize_gradients=True)
您可以从名称中包含"embeddings"
的训练中排除所有变量。如果只想训练所有变量,则不必定义train_vars
,并且可以在不使用参数variables_to_train
的情况下创建训练操作。