在TF Estimator中冻结和解冻网络层

时间:2018-12-26 22:21:56

标签: python tensorflow tensorflow-estimator

我正在使用TF Estimator在数据集中训练我的模型。对于前几次训练迭代,我想冻结网络中的某些层。对于其余的迭代,我想解冻这些层。

我找到了一些解决方案,其中在估算器的model_fn中有两个不同的优化器train_ops。

def ModelFunction(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.TRAIN:
        layerTrainingVars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "LayerName")
        #Train Op for freezing layers
        freeze_train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(), var_list=layerTrainingVars)
        #Train Op for training all layers
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        #Based on whether we want to freeze or not, we send the corresponding train_op to the estimatorSpec. How do I do this?
        estimatorSpec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=freeze_train_op)

    return estimatorSpec

对于上述解决方案,可以基于train_op返回相应的EstimatorSpec。我尝试使用freeze_train_op进行几次训练迭代,然后终止该过程,并更改train_op以使代码中没有冻结的层。完成此操作后,将出现一个检查点错误,该错误表明保存在检查点中的图形/变量不同。我猜第一组迭代没有保存冻结的图层。如何以编程方式切换train_ops,以便检查点也能正常工作?

是否有更好的方法来冻结/解冻用于TF.Estmator的训练层?

1 个答案:

答案 0 :(得分:0)

您可以将2个train_op分组在一起以返回它们。

def ModelFunction(features, labels, mode, params):
    if mode == tf.estimator.ModeKeys.TRAIN:
        layerTrainingVars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "LayerName")
        #Train Op for freezing layers
        freeze_train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(), var_list=layerTrainingVars)
        #Train Op for training all layers
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        estimatorSpec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=tf.group(freeze_train_op, train_op))

    return estimatorSpec

但这不会考虑不同的迭代。如果要在不同的迭代中训练不同的变量集,并且不想停止训练并从检查点加载权重,则需要使用会话。 Estimator API不允许会话管理。