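I'm working on a multi-label news classification task. I'm using a CNN in Keras, and in my code I've tried L2 regularization, shuffling, BatchNormalization, EarlyStopping and Dropout, but the model still seems to overfit: validation accuracy fluctuates around 0.5 and the validation loss is stuck.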
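For a sense of scale of the `l2(l=0.1)` setting: Keras's L2 regularizer adds `factor * sum(w**2)` for each regularized kernel, and that penalty is included in the `loss` and `val_loss` values Keras prints. The sketch below estimates the penalty at initialization for the four regularized kernels in this model, assuming the default `glorot_uniform` kernel initializer and a hypothetical `EMBEDDING_DIM` of 100 (the post does not give the value). `expected_glorot_l2` and `initial_l2_total` are illustrative helpers, not Keras functions.

```python
import math

def expected_glorot_l2(shape, factor):
    """Expected factor * sum(w**2) for a kernel of this shape drawn from
    glorot_uniform, i.e. Uniform(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    receptive = math.prod(shape[:-2])   # kernel_size for Conv1D, 1 for Dense
    fan_in = shape[-2] * receptive
    fan_out = shape[-1] * receptive
    n_weights = math.prod(shape)
    # E[w**2] for Uniform(-a, a) is a**2 / 3 = 2 / (fan_in + fan_out)
    return factor * n_weights * 2.0 / (fan_in + fan_out)

def initial_l2_total(embedding_dim, factor=0.1):
    """Sum over the four regularized kernels (biases carry no regularizer here)."""
    shapes = [
        (7, embedding_dim, 256),  # Conv1D(256, 7)
        (5, 256, 128),            # Conv1D(128, 5)
        (3, 128, 64),             # Conv1D(64, 3)
        (64, 128),                # Dense(128) applied to the 64-channel conv output
    ]
    return sum(expected_glorot_l2(s, factor) for s in shapes)

print(round(initial_l2_total(100), 1))  # hypothetical EMBEDDING_DIM = 100 -> 48.5
print(round(math.log(4), 3))            # cross-entropy of a uniform 4-class guess -> 1.386
```

Under those assumptions the penalty starts near 48.5, while a uniform guess over 4 classes costs only ln 4 ≈ 1.386 in cross-entropy, so the large early losses in the log (39.5163 after epoch 1) are plausibly dominated by the L2 term rather than by classification error. This is an expectation at initialization only; training shrinks the weights and the penalty with them.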
Here is my code:

```python
from tensorflow import keras

# MAX_NB_WORDS, EMBEDDING_DIM, X (padded token ids) and Y (one-hot labels)
# are defined earlier in my script (not shown).
epochs = 50
batch_size = 150

model = keras.Sequential()
model.add(keras.layers.Embedding(MAX_NB_WORDS, EMBEDDING_DIM, input_length=X.shape[1]))
model.add(keras.layers.SpatialDropout1D(0.4))
# Three conv blocks: Conv1D (ReLU, L2 = 0.1) -> Dropout -> BatchNorm -> ReLU
model.add(keras.layers.Conv1D(256, 7, activation='relu', kernel_regularizer=keras.regularizers.l2(0.1)))
model.add(keras.layers.Dropout(0.4))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.Conv1D(128, 5, activation='relu', kernel_regularizer=keras.regularizers.l2(0.1)))
model.add(keras.layers.Dropout(0.4))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.Conv1D(64, 3, activation='relu', kernel_regularizer=keras.regularizers.l2(0.1)))
model.add(keras.layers.Dropout(0.4))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
# Dense acts on the last axis of the 3D conv output; Flatten comes after it
model.add(keras.layers.Dense(128, kernel_regularizer=keras.regularizers.l2(0.1), activation='relu'))
model.add(keras.layers.Dropout(0.4))
model.add(keras.layers.Flatten())
# 4 output classes
model.add(keras.layers.Dense(4, activation='softmax'))

# Adam's default learning rate is 0.001
model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
# Stop once val_loss has not improved by at least 0.0001 for 3 consecutive epochs
model.fit(X, Y, epochs=epochs, shuffle=True, batch_size=batch_size, validation_split=0.05,
          callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, min_delta=0.0001)])
```
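The Keras documentation for `fit()` states that `validation_split` holds out the last samples of `x` and `y`, selected before shuffling, and that `shuffle=True` only shuffles the training data. The sketch below mimics that split in NumPy so the class balance of the held-out rows can be checked. `keras_style_split` and `validation_class_counts` are hypothetical helpers, and the total of 3517 samples is inferred from the training log (3341 + 176).

```python
import math
import numpy as np

def keras_style_split(n_samples, validation_split):
    """(train, validation) sizes, mimicking fit(validation_split=...):
    the validation rows are the LAST ones, taken before any shuffling."""
    split_at = int(math.floor(n_samples * (1.0 - validation_split)))
    return split_at, n_samples - split_at

def validation_class_counts(Y, validation_split):
    """Per-class counts of one-hot labels in the rows fit() would hold out."""
    Y = np.asarray(Y)
    split_at, _ = keras_style_split(len(Y), validation_split)
    return Y[split_at:].sum(axis=0).astype(int).tolist()

# Toy example: 40 one-hot labels sorted by class (10 per class).
Y_sorted = np.repeat(np.eye(4), 10, axis=0)
print(keras_style_split(3517, 0.05))            # (3341, 176)
print(validation_class_counts(Y_sorted, 0.25))  # [0, 0, 0, 10]
```

Calling `validation_class_counts(Y, 0.05)` on the real labels would show whether the 176 validation rows cover all four classes. If the data happens to be ordered by class, a lopsided slice would be consistent with `val_acc` sitting at exactly 0.5000 for several epochs, though the post does not show how the data is ordered.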
Here are the results:

```
Train on 3341 samples, validate on 176 samples
Epoch 1/50
3341/3341 [==============================] - 17s 5ms/step - loss: 39.5163 - acc: 0.3714 - val_loss: 25.3390 - val_acc: 0.3409
Epoch 2/50
3341/3341 [==============================] - 12s 4ms/step - loss: 19.8568 - acc: 0.4101 - val_loss: 12.8626 - val_acc: 0.1307
Epoch 3/50
3341/3341 [==============================] - 11s 3ms/step - loss: 9.5797 - acc: 0.4259 - val_loss: 6.8670 - val_acc: 0.4489
Epoch 4/50
3341/3341 [==============================] - 12s 4ms/step - loss: 5.2375 - acc: 0.4633 - val_loss: 4.1790 - val_acc: 0.5000
Epoch 5/50
3341/3341 [==============================] - 12s 4ms/step - loss: 3.1856 - acc: 0.5142 - val_loss: 2.9318 - val_acc: 0.5000
Epoch 6/50
3341/3341 [==============================] - 12s 4ms/step - loss: 2.1090 - acc: 0.5798 - val_loss: 2.2959 - val_acc: 0.5000
Epoch 7/50
3341/3341 [==============================] - 12s 3ms/step - loss: 1.4496 - acc: 0.6827 - val_loss: 1.9783 - val_acc: 0.5000
Epoch 8/50
3341/3341 [==============================] - 12s 4ms/step - loss: 1.0001 - acc: 0.7896 - val_loss: 1.8255 - val_acc: 0.1080
Epoch 9/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.7020 - acc: 0.8599 - val_loss: 1.6645 - val_acc: 0.3920
Epoch 10/50
3341/3341 [==============================] - 12s 3ms/step - loss: 0.4969 - acc: 0.9045 - val_loss: 1.5975 - val_acc: 0.4489
Epoch 11/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.3607 - acc: 0.9431 - val_loss: 1.6214 - val_acc: 0.4886
Epoch 12/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.2699 - acc: 0.9662 - val_loss: 1.4582 - val_acc: 0.4659
Epoch 13/50
3341/3341 [==============================] - 12s 3ms/step - loss: 0.2169 - acc: 0.9776 - val_loss: 1.5432 - val_acc: 0.4773
Epoch 14/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.1770 - acc: 0.9826 - val_loss: 1.4467 - val_acc: 0.4773
Epoch 15/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.1445 - acc: 0.9856 - val_loss: 1.5100 - val_acc: 0.4943
Epoch 16/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.1414 - acc: 0.9856 - val_loss: 1.4172 - val_acc: 0.4830
Epoch 17/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.1260 - acc: 0.9880 - val_loss: 1.4200 - val_acc: 0.4830
Epoch 18/50
3341/3341 [==============================] - 12s 4ms/step - loss: 0.1103 - acc: 0.9928 - val_loss: 1.4191 - val_acc: 0.4830
Epoch 19/50
3341/3341 [==============================] - 12s 3ms/step - loss: 0.0952 - acc: 0.9916 - val_loss: 1.4775 - val_acc: 0.4602
```
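To put numbers on the overfitting, the sketch below parses the epoch lines of a Keras 2-style log like this one (metric keys `loss`, `acc`, `val_loss`, `val_acc`) and reports the epoch with the lowest `val_loss` along with the train-minus-validation accuracy gap per epoch. `parse_keras_log` and `summarize` are hypothetical helpers; after training, the same values are also available in the `history` dict of the `History` object that `model.fit()` returns.

```python
import re

EPOCH_RE = re.compile(r"^Epoch (\d+)/\d+")
# \b keeps "loss"/"acc" from matching inside "val_loss"/"val_acc"
METRIC_RE = re.compile(r"\b(val_loss|val_acc|loss|acc): ([0-9.]+)")

def parse_keras_log(text):
    """Return {epoch: {'loss': .., 'acc': .., 'val_loss': .., 'val_acc': ..}}."""
    history, epoch = {}, None
    for line in text.splitlines():
        header = EPOCH_RE.match(line.strip())
        if header:
            epoch = int(header.group(1))
        elif epoch is not None and "val_loss" in line:
            history[epoch] = {k: float(v) for k, v in METRIC_RE.findall(line)}
    return history

def summarize(history):
    """Epoch with the lowest val_loss, plus acc - val_acc for every epoch."""
    best = min(history, key=lambda ep: history[ep]["val_loss"])
    gaps = {ep: round(m["acc"] - m["val_acc"], 4) for ep, m in history.items()}
    return best, gaps
```

Run on the full log above, this picks epoch 16 (val_loss 1.4172), where training accuracy is 0.9856 against 0.4830 on validation, a gap of about 0.50. Epochs 17 to 19 never beat 1.4172 by the 0.0001 `min_delta`, which is why `EarlyStopping(patience=3)` ended training at epoch 19.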
What can I do to fix this and improve the validation accuracy?