神经网络损失值不变

时间:2020-03-20 00:01:44

标签: python machine-learning keras neural-network

我是深度学习的新手,因此我制作了这个模型,以训练我的数据,我尝试了许多组合,添加层,更改激活函数,更改损失函数,但损失并未减少。 寻求帮助的人。

我的training_data包含1000个样本:1000个原始数据,以及20列所有数字,输出:4个数字的列表 这是我的模特:

from keras import models
from keras.models import Sequential


from keras import layers
from keras.layers import Dense
from keras.layers import Flatten , Dropout
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping
from sklearn.preprocessing import StandardScaler
from keras import optimizers

scaler = StandardScaler()
input_shape = x_train[0].shape
x_train_std = scaler.fit_transform(x_train)


model = Sequential()
model.add(layers.Dense(32, activation='sigmoid' , input_shape=input_shape))
model.add(Dropout(0.1))
model.add(layers.Dense(20, activation='sigmoid' ))
model.add(Dropout(0.1))
model.add(layers.Dense(15, activation='sigmoid' ))
model.add(Dropout(0.1))

model.add(layers.Dense(4, activation='softmax'))
#sgd = optimizers.SGD(lr=0.00001, decay=1e-6, momentum=0.85, nesterov=True)
#opt = SGD(lr=0.1, nesterov=True)
sgd = optimizers.SGD(lr=0.01, momentum=0.87, nesterov=True)
model.compile(loss='mean_squared_error',
              optimizer=sgd)
es = EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(x_train_std, y_train , validation_split=0.1, epochs=100, batch_size=1 , callbacks = [es])#,

1 个答案:

答案 0 :(得分:0)

由于您处于回归设置中,因此应softmax用作最后一层的激活函数-linear激活(如果存在,则为默认激活)未定义)应在此处使用。

对于中间层,也强烈不建议使用Sigmoid激活-而是使用relu

此外,默认情况下,您不应该使用辍学-在没有辍学的情况下开始,并且只有在提高验证性能的情况下才添加。

总而言之,这是初学者模型的外观:

model = Sequential()
model.add(layers.Dense(32, activation='relu' , input_shape=input_shape))
# model.add(Dropout(0.1))
model.add(layers.Dense(20, activation='relu' ))
# model.add(Dropout(0.1))
model.add(layers.Dense(15, activation='relu' ))
# model.add(Dropout(0.1))
model.add(layers.Dense(4)) # default activation='linear'

仅在需要时(不必完全相同),您才可以取消注释辍学层。

尝试使用Adam优化器也是第一种方法的好主意:

model.compile(loss='mean_squared_error',
              optimizer=keras.optimizers.Adam())

最后,您绝对应该增加batch_size(对于初学者,请尝试64或128)。