我想创建一个权重仅在训练阶段更新的自定义图层。
这是官方文档中的方法:
from keras import backend as K
from keras.layers import Layer
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # Be sure to call this at the end
def call(self, x):
return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
在this github repo中 作者补充了
new_centers = self.centers - self.alpha * delta_centers
self.add_update((self.centers, new_centers), x)
其中self.centers
是权重。
我不明白为什么self.add_update在这种情况下有用。
如果我不添加self.add_update,权重将不会更新?如果不是,为什么new_centers
必须出现在更新列表中而不是输入列表中?为什么x是必需项?
从源代码开始,
self.add_update(updates, inputs)
更新:更新操作或要添加到图层的更新操作列表。
inputs:输入张量或输入张量列表以将这些更新标记为有条件的更新。如果未传递任何条件,则假定更新为无条件的。
答案 0 :(得分:1)
权重有两种:
对于可训练的权重,实际上不建议使用更新,您将把优化程序的更新与自己的更新混合在一起,这可能会导致很多问题
对于无法训练的举重,您可以做任何想做的事情。有时您需要常数,而您什么也不做,有时,您希望这些权重发生变化(但不能通过反向传播)
请注意,在该示例中,用户更新的权重是如何不可训练的:
self.centers = self.add_weight(name='centers',
shape=(10, 2),
initializer='uniform',
#UNTRAINABLE
trainable=False)
但是用户希望遵循一些规则来更新这些权重。我不知道他们在做什么(没有分析代码),但是我假设他们正在计算,例如,类似于一组图像的中心点,并且每批将在该中心一个不同的位置。他们想更新这个职位。
经典示例是BatchNormalization
层。除了具有可训练的scale
和bias
权重用于重新缩放输出,它们还具有mean
和variance
权重。这些是数据的统计属性,需要每批更新。
您不是在训练“均值”或“方差”,但是每一批数据都会更新这些值。
这很晦涩,位于Keras代码的深处。
我们需要更新操作,因此请确保self.centers
将为每个批次具有新值,否则将没有。
我们在层中使用self.add_update
来注册该变量应被更新。 (我们也在自定义优化器中执行类似的操作,优化器包含通过反向传播进行权重的更新)
稍后在source code for training the model中,Keras将收集所有这些注册的更新并进行训练。在其中的某个位置,这些更新将应用于vars:
#inside a training function from keras
with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = (self.updates + #probably the updates registered in layers
training_updates + #the updates registered in optimizers
self.metrics_updates) #don't know....
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(
inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)