是否可以显示批次总和而不是批次平均数作为keras损失?

时间:2020-04-26 07:29:38

标签: python tensorflow keras deep-learning

我有一项回归任务,正在使用欧几里得距离来测量拟合度。我不想显示均方误差作为损失,而是要显示平方和。也就是说,我想求平方误差项的总和,而除以示例数。

在批处理级别,我可以通过定义自定义损失来实现此目的(也许我可以直接使用tf.keras.losses.MeanSquareError

class CustomLoss(tf.keras.losses.Loss):
    def call(self, Y_true, Y_pred):
        return tf.reduce_sum(tf.math.abs(Y_true-Y_pred) ** 2, axis=-1)

target_loss=CustomLoss(reduction=tf.keras.losses.Reduction.SUM)

这将为每个示例计算平方误差,然后指示TensorFlow对示例进行SUM运算,以计算批次损失,而不是默认的SUM_OVER_BATCH_SIZE(不应从字面上读取,而应作为分数读取,即, SUM / BATCH_SIZE

我的问题是,在一个纪元水平上,Keras取这些和,然后计算跨步(批)的平均值,以报告该纪元的损失。 我如何让Keras计算批次的总和而不是均值?

2 个答案:

答案 0 :(得分:0)

您将必须编写一个Custom Callback,它将在每一批之后将损失附加到列表中(如共享链接文档中所示)。

实施 on_epoch_end以获得列表中所有值的总和(您在其中添加了所有批次损失)

如果要使所有批次的损失总和最小化,请使用K.Function API。 Full implementation

答案 1 :(得分:0)

您可以对 tf.keras.metric.Metric 中的批次求和,如下所示,但现在 2.4.x 中有一个未决问题(请参阅 this GitHub issue),不过您可以尝试使用 2.3.2,

class AddAllOnes(tf.keras.metrics.Metric):
  """ A simple metric that adds all the one's in current batch and suppose to return the total ones seen at every end of batch"""
    def __init__(self, name="add_all_ones", **kwargs):
        super(AddAllOnes, self).__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):    
        self.total.assign_add(tf.cast(tf.reduce_sum(y_true), dtype=tf.float32))
        
    def result(self):
        print('')
        print('inside result...', self.total)
        return self.total

X_train = np.random.random((512, 8))
y_train = np.random.randint(0, 2, (512, 1))

K.clear_session()
model_inputs = Input(shape=(8,))
model_unit = Dense(256, activation='linear', use_bias=False)(model_inputs)
model_unit = BatchNormalization()(model_unit)
model_unit = Activation('sigmoid')(model_unit)
model_outputs = Dense(1, activation='sigmoid')(model_unit)
optim = Adam(learning_rate=0.001)
model = Model(inputs=model_inputs, outputs=model_outputs)
model.compile(loss='binary_crossentropy', optimizer=optim, metrics=[AddAllOnes()], run_eagerly=True)
model.fit(X_train, y_train, verbose=1, batch_size=32)