模型似乎过度使用Optimizer.minimize()而不是tf.contrib.layers.optimize_loss()

时间:2017-12-20 14:20:49

标签: tensorflow machine-learning neural-network conv-neural-network

当我像这样创建train_op时:

train_op = tf.contrib.layers.optimize_loss(
    loss=loss,
    global_step=tf.contrib.framework.get_global_step(),
    learning_rate=params['learning_rate'],
    optimizer='Adam'
)

我得到了一个在验证和测试集上表现良好的工作网络。

如果我只使用这样的minimize()方法:

optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
train_op = optimizer.minimize(
    loss=loss,
    global_step=tf.train.get_global_step()
)

即使在1000步之后的第一次验证中,我的结果(精确度,召回率,损失)也会变得更糟,过了一段时间它似乎完全过度装配(验证损失或多或少是恒定的,是100倍的火车损失,但是精确和召回崩溃)

我创建了一个功能,它是contrib one的清理版本,与两个标记位置的直接Optimizer.minimize()不同:

def make_train_op(loss, optimizer, global_step):
    with tf.variable_scope(None, "OptimizeLoss", [loss, global_step]):

        # ==========================================
        # this part is extra comparing to minimize()
        update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        if update_ops:
            with tf.control_dependencies([update_ops]):
                loss = tf.identity(loss)
        # ==========================================

        gradients = optimizer.calculate_gradients(
            loss,
            tf.trainable_variables()
        )

        grad_updates = optimizer.apply_gradients(
            gradients,
            global_step=global_step,
            name="train")

        # ==========================================
        # so is this one
        with tf.control_dependencies([grad_updates]):
            train_op = tf.identity(loss)
        # ==========================================
        return train_op

验证再次表现良好。在所有情况下,培训看起来或多或少相同(和健康)。网络是相对简单的CNN / batchnorm / dropout / maxpool混合交叉熵损失。

我理解这一点的方法是,某些操作是图形的一部分,不会显示为丢失的依赖关系,但这是计算渐变所需的。这怎么可能呢?如果这是正常情况,为什么这两个片段不是核心的一部分?我是否应该在构建模型时做一些不同的事情以避免需要这种依赖强制?

1 个答案:

答案 0 :(得分:2)

问题在于batchnorm更新操作,实际上是Measurement Protocol

  

注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)