我的自定义Keras图层中的add_update()不会更新权重

时间:2020-04-07 02:53:09

标签: python tensorflow keras

因此,我正在实施中心损失:https://ydwen.github.io/papers/WenECCV16.pdf,并且在更新图层权重时遇到问题,这意味着更新中心损失中的中心。当我像这样<input type="text" placeholder="Type 'Help1' for actions"> <button>Confirm</button> <div></div>打印我的class_centers时,它们永远不会改变。当我打印其他变量时,它们看起来还不错,所以我唯一想到的问题是add_update()并没有执行应做的工作。

自定义层:

tf.print(self.class_centers, summarize=-1, output_stream='file:///tensors.txt')

最后的损失是:

class CenterLossLayer(Layer):
    def __init__(self, alpha=0.5, **kwargs):
        self.alpha = alpha  
        super(CenterLossLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        print('Center loss input 1 (feature_size): ', input_shape[0][1])
        print('Center loss input 2 (num_classes): ', input_shape[1][1])
        self.class_centers = self.add_weight(name='class_centers',
                                       shape=(input_shape[1][1], input_shape[0][1]),
                                       initializer='uniform',
                                       trainable=False)
        super(CenterLossLayer, self).build(input_shape)

    def call(self, x, mask=None):
        embeddings, one_hots = x
        tf.print(self.class_centers, summarize=-1, output_stream='file:///tensors.txt')

        batch_centers = K.dot(one_hots, self.class_centers)
        batch_delta = batch_centers - embeddings

        class_delta = K.dot(K.transpose(one_hots), batch_delta)
        counts = K.sum(K.transpose(one_hots), axis=1, keepdims=True) + 1
        class_delta = class_delta / counts
        class_delta = K.in_train_phase(self.alpha * class_delta, 0 * class_delta)

        updated_class_centers = self.class_centers - class_delta
        self.add_update((self.class_centers, updated_class_centers), x[0])

        losses = K.sum(K.square(embeddings - batch_centers), axis=1, keepdims=True)

        return losses

    def compute_output_shape(self, input_shape):
        return (input_shape[1][0], )

其中def batch_mean_loss(y_true, y_pred): return K.mean(y_pred, axis=0) 是CenterLossLayer的y_pred

奇怪的是,即使以为中心没有更新,每个阶段的中心损失都在下降,而最终模型比只训练有Softmax损失的模型更好。

2 个答案:

答案 0 :(得分:1)

因此,我检查了add_update()在BatchNormalization层中的使用方式:

self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs)

问题是方法add_update()的第一个参数是“ updates:Update op”,而moving_average_update()返回“操作操作以更新变量”。因此,我猜想add_update()需要某种操作,而moving_average_update()返回该操作。我不知道如何创建此操作,所以我做了:

self.add_update(K.moving_average_update(self.class_centers, updated_class_centers, 0.0), x)

因此它的功能就像用self.class_centers替换updated_class_centers一样。

即使认为它可行,如果有人知道如何正确执行此操作,我将不胜感激。

答案 1 :(得分:0)

看起来您应该执行以下操作:

class ComputeSum(keras.layers.Layer):
  def __init__(self, input_dim):
    super(ComputeSum, self).__init__()
    self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

  def call(self, inputs):
    self.total.assign_add(tf.reduce_sum(inputs, axis=0))
    return self.total

摘录自https://keras.io/guides/making_new_layers_and_models_via_subclassing/#layers-can-have-nontrainable-weighto