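I'm trying to classify time-based biosignals with a CNN. My classifier currently reports 91% accuracy, which would beat the existing record for these signals, so I want to know whether the result is genuine. I'm also suspicious because the training itself behaves strangely, and I couldn't find much explanation online of whether training that looks like this is right or wrong.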
```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# Code used to print the results
score = model.evaluate(test_x, test_y, verbose=1)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# test_y is one-hot encoded, so convert both the labels and the
# predicted probabilities to integer class ids for sklearn
pred_y = model.predict(test_x)
print(pred_y.shape)
testy = np.argmax(test_y, axis=1)
predy = np.argmax(pred_y, axis=1)
print(confusion_matrix(testy, predy))
print(classification_report(testy, predy))
```
This is the training run and its output:
```
Train on 980 samples, validate on 180 samples
Epoch 1/1000
 - 3s - loss: 2.4064 - acc: 0.8240 - mean_absolute_error: 0.1762 - val_loss: 2.4179 - val_acc: 0.8256 - val_mean_absolute_error: 0.1737
Epoch 2/1000
 - 1s - loss: 2.4962 - acc: 0.8241 - mean_absolute_error: 0.1759 - val_loss: 2.3047 - val_acc: 0.8289 - val_mean_absolute_error: 0.1716
Epoch 3/1000
 - 1s - loss: 2.4027 - acc: 0.8250 - mean_absolute_error: 0.1750 - val_loss: 2.1278 - val_acc: 0.8267 - val_mean_absolute_error: 0.1736
Epoch 4/1000
 - 1s - loss: 1.1608 - acc: 0.8455 - mean_absolute_error: 0.1780 - val_loss: 0.3370 - val_acc: 0.8983 - val_mean_absolute_error: 0.1792
Epoch 5/1000
 - 1s - loss: 0.3330 - acc: 0.8990 - mean_absolute_error: 0.1800 - val_loss: 0.3248 - val_acc: 0.9000 - val_mean_absolute_error: 0.1798
Epoch 6/1000
 - 1s - loss: 0.3263 - acc: 0.9000 - mean_absolute_error: 0.1800 - val_loss: 0.3247 - val_acc: 0.9000 - val_mean_absolute_error: 0.1798
.........
.........
.........
Epoch 993/1000
 - 1s - loss: 0.0430 - acc: 0.9832 - mean_absolute_error: 0.0282 - val_loss: 0.0374 - val_acc: 0.9894 - val_mean_absolute_error: 0.0178
Epoch 994/1000
 - 1s - loss: 0.0381 - acc: 0.9871 - mean_absolute_error: 0.0246 - val_loss: 0.0376 - val_acc: 0.9894 - val_mean_absolute_error: 0.0177
Epoch 995/1000
 - 1s - loss: 0.0379 - acc: 0.9852 - mean_absolute_error: 0.0250 - val_loss: 0.0376 - val_acc: 0.9894 - val_mean_absolute_error: 0.0177
Epoch 996/1000
 - 1s - loss: 0.0432 - acc: 0.9842 - mean_absolute_error: 0.0282 - val_loss: 0.0375 - val_acc: 0.9894 - val_mean_absolute_error: 0.0177
Epoch 997/1000
 - 1s - loss: 0.0352 - acc: 0.9874 - mean_absolute_error: 0.0248 - val_loss: 0.0376 - val_acc: 0.9894 - val_mean_absolute_error: 0.0175
Epoch 998/1000
 - 1s - loss: 0.0395 - acc: 0.9850 - mean_absolute_error: 0.0256 - val_loss: 0.0378 - val_acc: 0.9894 - val_mean_absolute_error: 0.0174
Epoch 999/1000
 - 1s - loss: 0.0371 - acc: 0.9860 - mean_absolute_error: 0.0250 - val_loss: 0.0380 - val_acc: 0.9894 - val_mean_absolute_error: 0.0173
Epoch 1000/1000
 - 1s - loss: 0.0353 - acc: 0.9873 - mean_absolute_error: 0.0243 - val_loss: 0.0379 - val_acc: 0.9894 - val_mean_absolute_error: 0.0173
```
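One thing worth checking in this log: around epochs 5 and 6, both `acc` and `val_acc` settle at 0.9000 and `mean_absolute_error` at about 0.18. Those are exactly the numbers produced by a model that outputs roughly 0.1 for each of 10 classes, if the reported "accuracy" is element-wise binary accuracy. Keras reports binary accuracy for `metrics=['accuracy']` when the loss is `binary_crossentropy`; whether this model was compiled that way is an assumption, because the compile call isn't shown. The NumPy sketch below uses hypothetical balanced one-hot labels (`labels`, `y_true`, `y_prob` are made-up names) to compare binary accuracy, argmax-based categorical accuracy and MAE for such an uninformative model.

```python
import numpy as np

# Hypothetical balanced 10-class labels: 120 samples, 12 per class
num_classes = 10
labels = np.repeat(np.arange(num_classes), 12)
y_true = np.eye(num_classes)[labels]          # one-hot, shape (120, 10)

# A "model" that outputs 0.1 for every class, i.e. it has learned nothing
y_prob = np.full_like(y_true, 0.1)

# Element-wise binary accuracy: each of the 10 outputs is thresholded
# at 0.5 and compared with its one-hot entry independently
binary_acc = np.mean((y_prob > 0.5) == (y_true == 1))

# Categorical accuracy: one decision per sample via argmax
categorical_acc = np.mean(np.argmax(y_prob, axis=1) == labels)

# Mean absolute error over all 120 x 10 entries
mae = np.mean(np.abs(y_prob - y_true))

print(f"binary accuracy:      {binary_acc:.4f}")
print(f"categorical accuracy: {categorical_acc:.4f}")
print(f"mean absolute error:  {mae:.4f}")
```

Predicting 0 for every output is correct on 9 of the 10 one-hot entries, so binary accuracy is 0.9 even though the argmax prediction is no better than chance. If the model here was compiled with `binary_crossentropy`, the `acc` values in the log and the test accuracy from `model.evaluate` would be measured on this element-wise scale, not per sample.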
The model structure is as follows:
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv1d_49 (Conv1D)           (None, 798, 10)           310
_________________________________________________________________
conv1d_50 (Conv1D)           (None, 796, 5)            155
_________________________________________________________________
max_pooling1d_25 (MaxPooling (None, 159, 5)            0
_________________________________________________________________
flatten_25 (Flatten)         (None, 795)               0
_________________________________________________________________
dropout_25 (Dropout)         (None, 795)               0
_________________________________________________________________
dense_25 (Dense)             (None, 10)                7960
=================================================================
Total params: 8,425
Trainable params: 8,425
Non-trainable params: 0
_________________________________________________________________
```
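The shapes and parameter counts in the summary can be checked by hand. The sketch below assumes 'valid' padding and stride 1 for both convolutions, `kernel_size=3` for both, `pool_size=5` for the pooling layer, and an input of 800 time steps × 10 channels. None of these values are stated in the post; they are one configuration that reproduces every number in the summary (other combinations for the first layer also fit, for example `kernel_size=30` on a single-channel input of 827 steps). The helper functions are hypothetical names, not Keras APIs.

```python
def conv1d_output_length(length, kernel_size, stride=1):
    """Output length of a 1-D convolution with 'valid' padding."""
    return (length - kernel_size) // stride + 1


def conv1d_params(kernel_size, in_channels, filters):
    """Weights (kernel_size * in_channels * filters) plus one bias per filter."""
    return kernel_size * in_channels * filters + filters


def dense_params(in_features, units):
    """Weight matrix plus one bias per unit."""
    return in_features * units + units


# Assumed input: 800 time steps x 10 channels (inferred, not stated in the post)
steps, channels = 800, 10

# conv1d_49: 10 filters, kernel_size 3 (assumed)
steps = conv1d_output_length(steps, 3)       # 798
p1 = conv1d_params(3, channels, 10)          # 310
channels = 10

# conv1d_50: 5 filters, kernel_size 3 (assumed)
steps = conv1d_output_length(steps, 3)       # 796
p2 = conv1d_params(3, channels, 5)           # 155
channels = 5

# max_pooling1d_25: pool_size 5 (assumed), stride equal to pool_size
steps = steps // 5                           # 159

# flatten_25 and dropout_25 have no parameters
flat = steps * channels                      # 795

# dense_25: 10 output units
p3 = dense_params(flat, 10)                  # 7960

total = p1 + p2 + p3                         # 8425
print(steps, flat, p1, p2, p3, total)
```

The script prints `159 795 310 155 7960 8425`, matching the summary row by row.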
The test-set results are as follows:
```
Test loss: 0.3470249831676483
Test accuracy: 0.9183333237965902
(120, 10)
[[7 0 4 0 0 0 1 0 0 0]
 [0 4 1 1 1 3 0 1 0 1]
 [2 0 8 0 0 0 2 0 0 0]
 [1 0 1 3 2 1 1 0 2 1]
 [0 0 0 0 8 2 0 0 1 1]
 [0 5 0 0 0 7 0 0 0 0]
 [4 0 1 0 0 0 6 0 0 1]
 [1 1 0 1 0 0 0 7 1 1]
 [0 1 1 0 1 0 0 0 9 0]
 [1 0 2 2 0 0 2 0 0 5]]
```
The classification report is as follows:
```
              precision    recall  f1-score   support

           0       0.44      0.58      0.50        12
           1       0.36      0.33      0.35        12
           2       0.44      0.67      0.53        12
           3       0.43      0.25      0.32        12
           4       0.67      0.67      0.67        12
           5       0.54      0.58      0.56        12
           6       0.50      0.50      0.50        12
           7       0.88      0.58      0.70        12
           8       0.69      0.75      0.72        12
           9       0.50      0.42      0.45        12

   micro avg       0.53      0.53      0.53       120
   macro avg       0.54      0.53      0.53       120
weighted avg       0.54      0.53      0.53       120
```
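The overall accuracy implied by the confusion matrix can be recomputed directly: correct predictions lie on the diagonal, so per-sample accuracy is the trace divided by the total count. The matrix in this sketch is copied from the test-set output above; recall and precision are included as a cross-check against the classification report.

```python
import numpy as np

# Confusion matrix copied from the test-set output
# (rows = true class, columns = predicted class)
cm = np.array([
    [7, 0, 4, 0, 0, 0, 1, 0, 0, 0],
    [0, 4, 1, 1, 1, 3, 0, 1, 0, 1],
    [2, 0, 8, 0, 0, 0, 2, 0, 0, 0],
    [1, 0, 1, 3, 2, 1, 1, 0, 2, 1],
    [0, 0, 0, 0, 8, 2, 0, 0, 1, 1],
    [0, 5, 0, 0, 0, 7, 0, 0, 0, 0],
    [4, 0, 1, 0, 0, 0, 6, 0, 0, 1],
    [1, 1, 0, 1, 0, 0, 0, 7, 1, 1],
    [0, 1, 1, 0, 1, 0, 0, 0, 9, 0],
    [1, 0, 2, 2, 0, 0, 2, 0, 0, 5],
])

correct = np.trace(cm)      # sum of the diagonal
total = cm.sum()            # number of test samples
accuracy = correct / total

# Recall = diagonal / row sums; precision = diagonal / column sums
recall = np.diag(cm) / cm.sum(axis=1)
precision = np.diag(cm) / cm.sum(axis=0)

print(f"accuracy from confusion matrix: {correct}/{total} = {accuracy:.4f}")
print("recall:   ", np.round(recall, 2))
print("precision:", np.round(precision, 2))
```

This gives 64/120 ≈ 0.5333, which agrees with the 0.53 micro average in the report but not with the 0.9183 printed by `model.evaluate`, so the two "accuracy" figures are evidently not measuring the same thing.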