在 TF2 keras 中,我使用 tensorflow.keras.losses.MeanSquaredError 作为损失函数训练了一个自动编码器。现在,我想通过使用另一个损失函数来进一步训练这个模型,特别是 tensorflow.keras.losses.KLDivergence。这样做的原因是最初进行无监督学习是为了表示学习。然后,有了生成的嵌入,我可以将它们聚类并使用这些聚类进行自我监督,即标签,启用第二个监督损失并进一步改进模型。
这本身不是迁移学习,因为没有向模型添加新层,只是改变了损失函数,模型继续训练。
我尝试过的是使用预训练模型和 MSE 损失作为新模型的属性:
class ClusterBooster(tf.keras.Model):
def __init__(self, base_model, centers):
super(ClusterBooster, self).__init__()
self.pretrained = base_model
self.centers = centers
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self.compiled_loss(self.P, self.Q, regularization_losses=self.losses)
# Compute gradients
gradients = tape.gradient(loss, self.trainable_variables)
# Update weights
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
return {m.name: m.result() for m in self.metrics}
其中损失是分布 P 和 Q 之间的 KL 损失。分布在回调函数中计算,而不是在模型 train_step 中计算,因为我需要访问当前时期(P 每 5 个时期更新一次,而不是在每个时期更新) :
def on_epoch_begin(self, epoch, logs=None):
z = self.model.pretrained.embed(self.feature, training=True)
z = tf.reshape(z, [tf.shape(z)[0], 1, tf.shape(z)[1]]) # reshape for broadcasting
# CALCULATE Q FOR EVERY EPOCH
partial = tf.math.pow(tf.norm(z - self.model.centers, axis=2, ord='euclidean'), 2)
nominator = 1 / (1 + partial)
denominator = tf.math.reduce_sum(1 / (1 + partial))
self.model.Q = nominator / denominator
# CALCULATE P EVERY 5 EPOCHS TO AVOID INSTABILITY
if epoch % 5 == 0:
partial = tf.math.pow(self.model.Q, 2) / tf.math.reduce_sum(self.model.Q, axis=1, keepdims=True)
nominator = partial
denominator = tf.math.reduce_sum(partial, axis=0)
self.model.P = nominator / denominator
但是,当执行 apply_gradients() 时,我得到:
ValueError: No gradients provided for any variable: ['dense/kernel:0', 'dense/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0', 'dense_3/kernel:0', 'dense_3/bias:0']
我认为这是由于预训练模型未设置为在新模型内部的某处进一步训练(仅调用 embed() 方法,该方法不训练模型)。这是一种正确的方法,我只是遗漏了一些东西还是有更好的方法?
答案 0 :(得分:0)
似乎在回调中发生的任何计算都不会跟踪梯度计算和权重更新。因此,这些计算应该放在自定义模型类 (ClusterBooster) 的 train_step() 函数中。
假设我无法访问 ClusterBooster 的 train_step() 函数内的时期数,我创建了一个没有 Model 类的自定义训练循环,我可以在其中使用纯 python 代码(这是热切计算的)。