我正在尝试使用神经网络进行回归以预测146个输入要素的单个输出。
我在所有输入和输出上应用了标准缩放。
我在训练后监视平均绝对误差,并且在训练,验证和测试集上过高(我什至没有过拟合)。
我怀疑这是因为输出变量非常不平衡(请参见直方图)。 从直方图中可以看到,大多数样本都在0附近分组,但在-5附近还有另一小组样本。
Histogram of the imbalanced output
这是模型创建代码:
input = Input(batch_shape=(None, X.shape[1]))
layer1 = Dense(20, activation='relu')(input)
layer1 = Dropout(0.3)( layer1)
layer1 = BatchNormalization()(layer1)
layer2 = Dense(5, activation='relu',
kernel_regularizer='l2')(layer1)
layer2 = Dropout(0.3)(layer2)
layer2 = BatchNormalization()(layer2)
out_layer = Dense(1, activation='linear')(layer2)
model = Model(inputs=input, outputs=out_layer)
model.compile(loss='mean_squared_error', optimizer=optimizers.adam()
, metrics=['mae'])
这是模型摘要:
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 146) 0
_________________________________________________________________
dense_1 (Dense) (None, 20) 2940
_________________________________________________________________
dropout_1 (Dropout) (None, 20) 0
_________________________________________________________________
batch_normalization_1 (Batch (None, 20) 80
_________________________________________________________________
dense_2 (Dense) (None, 5) 105
_________________________________________________________________
dropout_2 (Dropout) (None, 5) 0
_________________________________________________________________
batch_normalization_2 (Batch (None, 5) 20
_________________________________________________________________
dense_3 (Dense) (None, 1) 6
=================================================================
Total params: 3,151
Trainable params: 3,101
Non-trainable params: 50
_________________________________________________________________
从实际模型预测来看,较大的误差主要发生在真实输出值约为-5的样本(较小的样本组)中。
我为超参数尝试了许多配置,但错误仍然很高。
我看到了很多关于对不平衡数据进行神经网络分类的建议,但是回归可以做什么? 在我看来,回归神经网络无法正确学习这一点很奇怪。我在做什么错了?
答案 0 :(得分:3)
从直方图中看,很少有非零输出。这类似于分类问题,在分类问题中,我们试图预测稀有类别,因为就损失函数而言,有效的策略只是猜测最常见的类别-在这种情况下,您的模态值为零。
您应该围绕人们如何预测稀有事件或在某些类别很少时对输入进行分类进行一些研究。例如。该讨论可能会有所帮助:https://www.reddit.com/r/MachineLearning/comments/412wpp/predicting_rare_events_how_to_prevent_machine/
您可以尝试使用的一些策略
希望有帮助!
答案 1 :(得分:1)
在我看来,您具有正态分布,并且标准偏差非常小。在这种情况下,这应该与其他概率分布一样训练。