Tf-slim中的variables_to_train标志

时间:2018-02-01 08:20:37

标签: tensorflow deep-learning tf-slim

我正在使用TF-Slim从预训练模型中微调我的模型。当我使用create_train_op时,我发现它的参数为variables_to_train。在一些教程中,它使用了如下标志:

   all_trainable = [v for v in tf.trainable_variables()]
   trainable     = [v for v in all_trainable]
   train_op      = slim.learning.create_train_op(
        opt,
        global_step=global_step,
        variables_to_train=trainable,
        summarize_gradients=True)

但在官方TF-Slim中,它没有使用

   all_trainable = [v for v in tf.trainable_variables()]
   trainable     = [v for v in all_trainable]
   train_op      = slim.learning.create_train_op(
        opt,
        global_step=global_step,            
        summarize_gradients=True)

那么,使用和不使用variables_to_train之间有什么不同?

1 个答案:

答案 0 :(得分:2)

您的两个示例都执行相同的操作。您训练图中出现的所有可训练变量。使用参数variables_to_train,您可以定义在训练期间应更新哪些变量。

一个用例是,当您已经预先训练了诸如词嵌入之类的不想在模型中训练的东西时。与

train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
train_op      = slim.learning.create_train_op(
    opt,
    global_step=global_step,
    variables_to_train=train_vars,
    summarize_gradients=True)

您可以从名称中包含"embeddings"的训练中排除所有变量。如果只想训练所有变量,则不必定义train_vars,并且可以在不使用参数variables_to_train的情况下创建训练操作。