如何在使用Keras(tensorflow)的神经网络回归中限制预测输出的总和

时间:2018-08-08 11:50:28

标签: machine-learning neural-network keras regression

我正在训练keras中的神经网络(python,后端:tensorflow)以用作回归。因此,我的输出层不包含激活函数,并且我将均方误差用作损失函数。

我的问题是:我想确保所有输出估计值的总和(几乎)等于所有实际标签的总和。

我的意思是:我不仅要确保每个训练示例i的(y_real)^ i〜(y_predict)^ i,还要确保sum(y_real)= sum(y_predict),求和我全部常规的线性回归使添加此限制变得非常简单,但是对于神经网络,我看不到任何类似的东西。我可以将最终结果乘以sum(y_real)/ sum(y_predict),但是如果我不想损害单个预测,恐怕这不是理想的方法。

我还有什么其他选择?

(我无法共享我的数据,也无法轻松地用其他数据重现该问题,但这是按要求使用的代码:)

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(128, activation = 'relu', input_dim = 459))
model.add(Dense(32, activation = 'relu'))
model.add(Dense(1))

model.compile(loss = 'mean_squared_error',
              optimizer = 'adam')

model.fit(X_train, Y_train, epochs = 5, validation_data = (X_val, 
Y_val), batch_size = 128)

1 个答案:

答案 0 :(得分:1)

从优化的角度来看,您想对问题引入等式约束。您正在寻找网络权重,以使预测y1_hat, y2_hat and y3_hat最小化均方误差,而不会加上标签y1, y2, y3。此外,您希望以下内容成立:

sum(y1, y2, y3) = sum(y1_hat, y2_hat, y3_hat)

由于您使用的是神经网络,因此您希望施加这种约束,以便仍然可以使用反向传播来训练网络。

做到这一点的一种方法是在损失函数中添加一个术语,以惩罚sum(y1, y2, y3)sum(y1_hat, y2_hat, y3_hat)之间的差异。

最小工作示例:

import numpy as np
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model

# Some random training data and labels
features = np.random.rand(100, 20)
labels = np.random.rand(100, 3)

# Simple neural net with three outputs
input_layer = Input((20,))
hidden_layer = Dense(16)(input_layer)
output_layer = Dense(3)(hidden_layer)

# Model
model = Model(inputs=input_layer, outputs=output_layer)

# Write a custom loss function
def custom_loss(y_true, y_pred):
    # Normal MSE loss
    mse = K.mean(K.square(y_true-y_pred), axis=-1)
    # Loss that penalizes differences between sum(predictions) and sum(labels)
    sum_constraint = K.square(K.sum(y_pred, axis=-1) - K.sum(y_true, axis=-1))

    return(mse+sum_constraint)

# Compile with custom loss
model.compile(loss=custom_loss, optimizer='sgd')
model.fit(features, labels, epochs=1, verbose=1)

请注意,这以“软”方式施加了约束,而不是硬约束。您仍然会得到偏差,但是网络应该以较小的方式学习权重。