使用修改的权重计算模型的梯度

时间:2020-12-20 15:55:17

标签: python tensorflow keras deep-learning tensorflow2.0

我正在使用 Tensorflow 实现 Sharpness Aware Minimization (SAM)。算法简化如下

  1. 使用当前权重 W 计算梯度
  2. 根据论文中的方程计算ε
  3. 使用权重 W + ε 计算梯度
  4. 使用第 3 步中的梯度更新模型

我已经实现了第 1 步和第 2 步,但是根据下面的代码在实现第 3 步时遇到了问题

def train_step(self, data, rho=0.05, p=2, q=2):
    if (1 / p) + (1 / q) != 1:
        raise tf.python.framework.errors_impl.InvalidArgumentError('p, q must be specified so that 1/p + 1/q = 1')
    x, y = data
        
    # compute first backprop
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(y, y_pred)
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
        
    # compute neighborhoods (epsilon_hat) from first backprop
    trainable_w_plus_epsilon_hat = [
        w + (rho * tf.sign(loss) * (tf.pow(tf.abs(g), q-1) / tf.math.pow(tf.norm(g, ord=q), q / p)))
        for w, g in zip(trainable_vars, gradients)
    ]
        
    ### HOW TO SET TRAINABLE WEIGHTS TO `w_plus_epsilon_hat`?
    #
    # TODO:
    #     1. compute gradient using trainable weights from `trainable_w_plus_epsilon_hat`
    #     2. update `trainable_vars` using gradient from step 1
    #
    #########################################################

    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}

是否可以使用来自 trainable_w_plus_epsilon_hat 的可训练权重来计算梯度?<​​/p>

0 个答案:

没有答案
相关问题