I am implementing Sharpness Aware Minimization (SAM) in TensorFlow. The simplified algorithm is as follows.
I have implemented steps 1 and 2, but I am running into trouble implementing step 3, given the code below:
def train_step(self, data, rho=0.05, p=2, q=2):
    if (1 / p) + (1 / q) != 1:
        raise ValueError('p, q must be specified so that 1/p + 1/q = 1')
    x, y = data
    # compute first backprop
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(y, y_pred)
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    # compute neighborhoods (epsilon_hat) from first backprop
    trainable_w_plus_epsilon_hat = [
        w + (rho * tf.sign(g) * (tf.pow(tf.abs(g), q - 1) / tf.pow(tf.norm(g, ord=q), q / p)))
        for w, g in zip(trainable_vars, gradients)
    ]
    ### HOW TO SET TRAINABLE WEIGHTS TO `w_plus_epsilon_hat`?
    #
    # TODO:
    # 1. compute gradient using trainable weights from `trainable_w_plus_epsilon_hat`
    # 2. update `trainable_vars` using gradient from step 1
    #
    #########################################################
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
Is it possible to compute the gradient using the trainable weights from `trainable_w_plus_epsilon_hat`?
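For reference, one way to handle step 3 is to not build `w + epsilon_hat` as a separate list of tensors at all, but to `assign_add` the perturbation onto the variables in place, run the second forward/backward pass at the perturbed weights, then `assign_sub` it back before applying the SAM gradient. A minimal sketch for the common p = q = 2 case (so `epsilon_hat = rho * g / ||g||_2`), assuming a TF 2.x-style subclassed `Model` with `compiled_loss`; the `SAMModel` class here is a hypothetical toy model, not the original code:

```python
import tensorflow as tf

class SAMModel(tf.keras.Model):
    """Toy subclassed model illustrating the two-step SAM update."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x, training=False):
        return self.dense(x)

    def train_step(self, data, rho=0.05):
        x, y = data
        # first backprop: gradient of the loss at the current weights w
        with tf.GradientTape() as tape:
            loss = self.compiled_loss(y, self(x, training=True))
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # epsilon_hat for p = q = 2: rho * g / ||g||_2 (global norm over all vars)
        grad_norm = tf.linalg.global_norm(gradients)
        epsilon_hat = [rho * g / (grad_norm + 1e-12) for g in gradients]
        # move the variables to w + epsilon_hat IN PLACE, so the second
        # forward pass sees the perturbed weights
        for w, e in zip(trainable_vars, epsilon_hat):
            w.assign_add(e)
        # second backprop: gradient at the perturbed weights
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        sam_gradients = tape.gradient(loss, trainable_vars)
        # restore the original weights, then apply the SAM gradient to them
        for w, e in zip(trainable_vars, epsilon_hat):
            w.assign_sub(e)
        self.optimizer.apply_gradients(zip(sam_gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```

Assigning in place avoids having to rebuild the model with new tensors; the only requirement is to undo the perturbation before `apply_gradients`, so the optimizer updates the original weights with the gradient measured at the perturbed point.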