验证准确性不会改善CNN

时间:2020-09-20 12:30:36

标签: python machine-learning keras deep-learning conv-neural-network

我对深度学习还很陌生,现在正在尝试根据EEG数据预测消费者的选择。整个数据集包含1045个EEG记录,每个记录都有一个相应的标签,表示产品的喜欢或不喜欢。班级分布如下(44%的喜欢和56%的不喜欢)。我读到卷积神经网络适合处理原始EEG数据,因此我尝试使用以下结构实现基于keras的网络:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(full_data, target, test_size=0.20, random_state=42)

y_train = np.asarray(y_train).astype('float32').reshape((-1,1))
y_test = np.asarray(y_test).astype('float32').reshape((-1,1))


# X_train.shape = ((836, 512, 14))
# y_train.shape = ((836, 1))

from keras.optimizers import Adam
from keras.optimizers import SGD
from keras.layers import MaxPooling1D
model = Sequential()

model.add(Conv1D(16, kernel_size=3, activation="relu", input_shape=(512,14)))

model.add(MaxPooling1D())

model.add(Conv1D(8, kernel_size=3, activation="relu"))

model.add(MaxPooling1D())

model.add(Flatten())

model.add(Dense(1, activation="sigmoid"))

model.compile(optimizer=Adam(lr = 0.001), loss='binary_crossentropy', metrics=['accuracy'])

model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=20, batch_size = 64)

当我拟合模型时,以下输出完全不会改变验证准确性:


Epoch 1/20
14/14 [==============================] - 0s 32ms/step - loss: 292.6353 - accuracy: 0.5383 - val_loss: 0.7884 - val_accuracy: 0.5407
Epoch 2/20
14/14 [==============================] - 0s 7ms/step - loss: 1.3748 - accuracy: 0.5598 - val_loss: 0.8860 - val_accuracy: 0.5502
Epoch 3/20
14/14 [==============================] - 0s 6ms/step - loss: 1.0537 - accuracy: 0.5598 - val_loss: 0.7629 - val_accuracy: 0.5455
Epoch 4/20
14/14 [==============================] - 0s 6ms/step - loss: 0.8827 - accuracy: 0.5598 - val_loss: 0.7010 - val_accuracy: 0.5455
Epoch 5/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7988 - accuracy: 0.5598 - val_loss: 0.8689 - val_accuracy: 0.5407
Epoch 6/20
14/14 [==============================] - 0s 6ms/step - loss: 1.0221 - accuracy: 0.5610 - val_loss: 0.6961 - val_accuracy: 0.5455
Epoch 7/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7415 - accuracy: 0.5598 - val_loss: 0.6945 - val_accuracy: 0.5455
Epoch 8/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7381 - accuracy: 0.5574 - val_loss: 0.7761 - val_accuracy: 0.5455
Epoch 9/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7326 - accuracy: 0.5598 - val_loss: 0.6926 - val_accuracy: 0.5455
Epoch 10/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7338 - accuracy: 0.5598 - val_loss: 0.6917 - val_accuracy: 0.5455
Epoch 11/20
14/14 [==============================] - 0s 7ms/step - loss: 0.7203 - accuracy: 0.5610 - val_loss: 0.6916 - val_accuracy: 0.5455
Epoch 12/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7192 - accuracy: 0.5610 - val_loss: 0.6914 - val_accuracy: 0.5455
Epoch 13/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7174 - accuracy: 0.5610 - val_loss: 0.6912 - val_accuracy: 0.5455
Epoch 14/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7155 - accuracy: 0.5610 - val_loss: 0.6911 - val_accuracy: 0.5455
Epoch 15/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7143 - accuracy: 0.5610 - val_loss: 0.6910 - val_accuracy: 0.5455
Epoch 16/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7129 - accuracy: 0.5610 - val_loss: 0.6909 - val_accuracy: 0.5455
Epoch 17/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7114 - accuracy: 0.5610 - val_loss: 0.6907 - val_accuracy: 0.5455
Epoch 18/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7103 - accuracy: 0.5610 - val_loss: 0.6906 - val_accuracy: 0.5455
Epoch 19/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7088 - accuracy: 0.5610 - val_loss: 0.6906 - val_accuracy: 0.5455
Epoch 20/20
14/14 [==============================] - 0s 6ms/step - loss: 0.7075 - accuracy: 0.5610 - val_loss: 0.6905 - val_accuracy: 0.5455

在此先感谢您的见解!

1 个答案:

答案 0 :(得分:0)

您遇到的现象称为underfitting。当我们的培训数据质量不足,或者您的网络体系结构太小而无法学习问题时,就会发生这种情况。

尝试规范化您的输入数据,并尝试使用不同的网络体系结构,学习率和激活功能。

正如@Muhammad Shahzad在其评论中所述,在扁平化之后添加一些密集层将是您应该尝试的具体体系结构改编。