How to improve the accuracy of a CNN-LSTM classifier in Keras

Time: 2018-03-17 09:52:51

Tags: python tensorflow machine-learning deep-learning keras

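I am developing a CNN-LSTM classifier in Keras that takes DNA sequences of length 500 as input and tries to classify each sequence as class 1 or class 0. I took the dataset and the best benchmark results from this research paper [https://file.scirp.org/pdf/JBiSE_2016042713533805.pdf]. My goal is to beat the benchmarks reported in that paper, which uses a CNN architecture. My approach is to combine a CNN with an LSTM to try to improve on the accuracy they achieved.

The problem is that their accuracy is consistently 2-3% higher than what this architecture achieves. Is my model correct, and can it be improved in any way to get better results, so that I can at least raise my accuracy?

My code is as follows: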

from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, Dropout, LSTM, Dense

# Each sequence is read, divided into 3-grams, and then one-hot encoded as in the paper.
# `dataset` is a project-specific module (not shown).
X_train_seqs, y_train, X_test_seqs, y_test, total_no_seqs = dataset.load_data()
X_train, y_train, X_test, y_test = dataset.one_hot_encode_128(X_train_seqs, y_train, X_test_seqs, y_test)

model = Sequential()
# Input: 497 time steps, each a 128-dimensional one-hot vector.
model.add(Conv1D(filters=496, kernel_size=2, input_shape=(497, 128), padding='same', activation='relu'))
model.add(MaxPooling1D(pool_size=2))
model.add(Dropout(0.5))
model.add(LSTM(units=100, dropout=0.3, recurrent_dropout=0.3))
model.add(Dropout(0.5))
# Two-way softmax output, so the labels must be one-hot encoded.
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# `epochs` replaces the deprecated `nb_epoch` argument; callbacks_list is defined elsewhere (not shown).
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=90, batch_size=64, callbacks=callbacks_list, verbose=2)
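
The code for dataset.one_hot_encode_128 is not shown, so the following is only a guess at how a 500-base sequence could end up with shape (497, 128). It assumes an A/C/G/T alphabet (4**3 = 64 possible 3-grams), overlapping 3-grams (498 per sequence), and concatenation of each 3-gram's one-hot vector with that of its successor (497 rows of 64 + 64 columns). The helper names one_hot_3grams and pair_adjacent are hypothetical and do not come from the paper or from the code above.

```python
import itertools
import numpy as np

# Assumption: the alphabet is A, C, G, T, giving 4**3 = 64 possible 3-grams.
KMERS = ["".join(p) for p in itertools.product("ACGT", repeat=3)]
KMER_INDEX = {kmer: i for i, kmer in enumerate(KMERS)}

def one_hot_3grams(seq):
    """Return an array of shape (len(seq) - 2, 64), one row per overlapping 3-gram."""
    n = len(seq) - 2
    out = np.zeros((n, len(KMERS)), dtype=np.float32)
    for i in range(n):
        out[i, KMER_INDEX[seq[i:i + 3]]] = 1.0
    return out

def pair_adjacent(encoded):
    """Concatenate each row with its successor: (n, 64) -> (n - 1, 128)."""
    return np.concatenate([encoded[:-1], encoded[1:]], axis=1)

seq = "ACGT" * 125                      # a dummy sequence of length 500
x = pair_adjacent(one_hot_3grams(seq))
print(x.shape)                          # (497, 128)
```

If the real encoder works differently, only the shape arithmetic carries over: whatever it does has to produce 497 steps of 128 features to match input_shape=(497, 128).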

For example, for this dataset: H3K9ac [the paper's result is about 79-79.2% accuracy]

My results are as follows:

Train on 25003 samples, validate on 2778 samples
Epoch 1/1000
 - 776s - loss: 0.6883 - acc: 0.5476 - val_loss: 0.6859 - val_acc: 0.5436
Epoch 2/1000
 - 671s - loss: 0.6736 - acc: 0.5767 - val_loss: 0.6435 - val_acc: 0.6483
Epoch 3/1000
 - 678s - loss: 0.6457 - acc: 0.6268 - val_loss: 0.6120 - val_acc: 0.6976
Epoch 4/1000
 - 681s - loss: 0.6193 - acc: 0.6630 - val_loss: 0.6148 - val_acc: 0.6595
Epoch 5/1000
 - 661s - loss: 0.6042 - acc: 0.6741 - val_loss: 0.6056 - val_acc: 0.6692
Epoch 6/1000
 - 671s - loss: 0.5925 - acc: 0.6908 - val_loss: 0.5632 - val_acc: 0.7221
Epoch 7/1000
 - 666s - loss: 0.5803 - acc: 0.7051 - val_loss: 0.5606 - val_acc: 0.7261
Epoch 8/1000
 - 680s - loss: 0.5744 - acc: 0.7106 - val_loss: 0.5521 - val_acc: 0.7322
Epoch 9/1000
 - 662s - loss: 0.5603 - acc: 0.7238 - val_loss: 0.5433 - val_acc: 0.7351
Epoch 10/1000
 - 661s - loss: 0.5542 - acc: 0.7258 - val_loss: 0.5339 - val_acc: 0.7433
Epoch 11/1000
 - 664s - loss: 0.5451 - acc: 0.7315 - val_loss: 0.5287 - val_acc: 0.7477
Epoch 12/1000
 - 663s - loss: 0.5403 - acc: 0.7357 - val_loss: 0.5262 - val_acc: 0.7462
Epoch 13/1000
 - 673s - loss: 0.5355 - acc: 0.7401 - val_loss: 0.5176 - val_acc: 0.7523
Epoch 14/1000
 - 669s - loss: 0.5273 - acc: 0.7452 - val_loss: 0.5146 - val_acc: 0.7520
Epoch 15/1000
 - 673s - loss: 0.5234 - acc: 0.7458 - val_loss: 0.5155 - val_acc: 0.7505
Epoch 16/1000
 - 673s - loss: 0.5196 - acc: 0.7510 - val_loss: 0.5124 - val_acc: 0.7552
Epoch 17/1000
 - 667s - loss: 0.5136 - acc: 0.7550 - val_loss: 0.5073 - val_acc: 0.7581
Epoch 18/1000
 - 657s - loss: 0.5149 - acc: 0.7518 - val_loss: 0.5055 - val_acc: 0.7556
Epoch 19/1000
 - 671s - loss: 0.5091 - acc: 0.7556 - val_loss: 0.5089 - val_acc: 0.7577
Epoch 20/1000
 - 658s - loss: 0.5060 - acc: 0.7593 - val_loss: 0.5066 - val_acc: 0.7585
Epoch 21/1000
 - 679s - loss: 0.5027 - acc: 0.7609 - val_loss: 0.5084 - val_acc: 0.7613
Epoch 22/1000
 - 681s - loss: 0.4989 - acc: 0.7636 - val_loss: 0.5113 - val_acc: 0.7477
Epoch 23/1000
 - 657s - loss: 0.4987 - acc: 0.7647 - val_loss: 0.5146 - val_acc: 0.7527
Epoch 24/1000
 - 659s - loss: 0.4963 - acc: 0.7643 - val_loss: 0.5032 - val_acc: 0.7631
Epoch 25/1000
 - 657s - loss: 0.4919 - acc: 0.7701 - val_loss: 0.5052 - val_acc: 0.7639
Epoch 26/1000
 - 668s - loss: 0.4900 - acc: 0.7674 - val_loss: 0.4989 - val_acc: 0.7649
Epoch 27/1000
 - 666s - loss: 0.4860 - acc: 0.7728 - val_loss: 0.5010 - val_acc: 0.7642
Epoch 28/1000
 - 657s - loss: 0.4871 - acc: 0.7718 - val_loss: 0.5091 - val_acc: 0.7570
Epoch 29/1000
 - 668s - loss: 0.4804 - acc: 0.7741 - val_loss: 0.5115 - val_acc: 0.7570
Epoch 30/1000
 - 666s - loss: 0.4826 - acc: 0.7762 - val_loss: 0.4930 - val_acc: 0.7696
Epoch 31/1000
 - 665s - loss: 0.4755 - acc: 0.7779 - val_loss: 0.5142 - val_acc: 0.7491
Epoch 32/1000
 - 662s - loss: 0.4753 - acc: 0.7785 - val_loss: 0.4968 - val_acc: 0.7646
Epoch 33/1000
 - 687s - loss: 0.4707 - acc: 0.7821 - val_loss: 0.5078 - val_acc: 0.7595
Epoch 34/1000
 - 667s - loss: 0.4709 - acc: 0.7809 - val_loss: 0.4968 - val_acc: 0.7621
Epoch 35/1000
 - 667s - loss: 0.4671 - acc: 0.7822 - val_loss: 0.5021 - val_acc: 0.7599
Epoch 36/1000
 - 666s - loss: 0.4642 - acc: 0.7849 - val_loss: 0.5074 - val_acc: 0.7617
Epoch 37/1000
 - 659s - loss: 0.4651 - acc: 0.7865 - val_loss: 0.4943 - val_acc: 0.7689
Epoch 38/1000
 - 657s - loss: 0.4607 - acc: 0.7859 - val_loss: 0.4935 - val_acc: 0.7639
Epoch 39/1000
 - 663s - loss: 0.4627 - acc: 0.7854 - val_loss: 0.4912 - val_acc: 0.7711
Epoch 40/1000
 - 671s - loss: 0.4574 - acc: 0.7857 - val_loss: 0.4921 - val_acc: 0.7671
Epoch 41/1000
 - 659s - loss: 0.4585 - acc: 0.7878 - val_loss: 0.4917 - val_acc: 0.7729
Epoch 42/1000
 - 671s - loss: 0.4557 - acc: 0.7921 - val_loss: 0.4969 - val_acc: 0.7711
Epoch 43/1000
 - 673s - loss: 0.4528 - acc: 0.7933 - val_loss: 0.4958 - val_acc: 0.7711
Epoch 44/1000
 - 674s - loss: 0.4509 - acc: 0.7930 - val_loss: 0.4998 - val_acc: 0.7664
Epoch 45/1000
 - 657s - loss: 0.4509 - acc: 0.7949 - val_loss: 0.4951 - val_acc: 0.7725
Epoch 46/1000
 - 665s - loss: 0.4470 - acc: 0.7968 - val_loss: 0.4982 - val_acc: 0.7649
Epoch 47/1000
 - 672s - loss: 0.4458 - acc: 0.7981 - val_loss: 0.4944 - val_acc: 0.7682
Epoch 48/1000
 - 679s - loss: 0.4385 - acc: 0.8017 - val_loss: 0.4985 - val_acc: 0.7675
Epoch 49/1000
 - 670s - loss: 0.4437 - acc: 0.8009 - val_loss: 0.4958 - val_acc: 0.7685
Epoch 50/1000
 - 665s - loss: 0.4408 - acc: 0.7991 - val_loss: 0.4926 - val_acc: 0.7714
Epoch 51/1000
 - 659s - loss: 0.4389 - acc: 0.8018 - val_loss: 0.5010 - val_acc: 0.7617
Epoch 52/1000
 - 659s - loss: 0.4361 - acc: 0.8009 - val_loss: 0.5024 - val_acc: 0.7592
Epoch 53/1000
 - 662s - loss: 0.4341 - acc: 0.8009 - val_loss: 0.4968 - val_acc: 0.7693
Epoch 54/1000
 - 658s - loss: 0.4327 - acc: 0.8046 - val_loss: 0.5051 - val_acc: 0.7588
Epoch 55/1000
 - 671s - loss: 0.4331 - acc: 0.8044 - val_loss: 0.4985 - val_acc: 0.7675
Epoch 56/1000
 - 669s - loss: 0.4302 - acc: 0.8075 - val_loss: 0.5022 - val_acc: 0.7635
Epoch 57/1000
 - 674s - loss: 0.4316 - acc: 0.8073 - val_loss: 0.4959 - val_acc: 0.7739
Epoch 58/1000
 - 673s - loss: 0.4292 - acc: 0.8046 - val_loss: 0.5020 - val_acc: 0.7639
Epoch 59/1000
 - 674s - loss: 0.4299 - acc: 0.8073 - val_loss: 0.4999 - val_acc: 0.7772
Epoch 60/1000
 - 670s - loss: 0.4249 - acc: 0.8071 - val_loss: 0.5081 - val_acc: 0.7624
Epoch 61/1000
 - 675s - loss: 0.4239 - acc: 0.8067 - val_loss: 0.4935 - val_acc: 0.7682
Epoch 62/1000
 - 667s - loss: 0.4163 - acc: 0.8139 - val_loss: 0.4913 - val_acc: 0.7682
Epoch 63/1000
 - 666s - loss: 0.4224 - acc: 0.8106 - val_loss: 0.5047 - val_acc: 0.7624
Epoch 64/1000
 - 657s - loss: 0.4231 - acc: 0.8092 - val_loss: 0.4953 - val_acc: 0.7689
Epoch 65/1000
 - 674s - loss: 0.4180 - acc: 0.8116 - val_loss: 0.4877 - val_acc: 0.7757
Epoch 66/1000
 - 679s - loss: 0.4155 - acc: 0.8152 - val_loss: 0.5058 - val_acc: 0.7613
Epoch 67/1000
 - 676s - loss: 0.4173 - acc: 0.8123 - val_loss: 0.5043 - val_acc: 0.7653
Epoch 68/1000
 - 683s - loss: 0.4205 - acc: 0.8099 - val_loss: 0.5070 - val_acc: 0.7660
Epoch 69/1000
 - 660s - loss: 0.4114 - acc: 0.8171 - val_loss: 0.5033 - val_acc: 0.7757
Epoch 70/1000
 - 667s - loss: 0.4156 - acc: 0.8137 - val_loss: 0.4897 - val_acc: 0.7750
Epoch 71/1000
 - 700s - loss: 0.4150 - acc: 0.8141 - val_loss: 0.5064 - val_acc: 0.7567
Epoch 72/1000
 - 661s - loss: 0.4075 - acc: 0.8191 - val_loss: 0.4996 - val_acc: 0.7675
Epoch 73/1000
 - 662s - loss: 0.4103 - acc: 0.8185 - val_loss: 0.4965 - val_acc: 0.7696
Epoch 74/1000
 - 680s - loss: 0.4069 - acc: 0.8184 - val_loss: 0.5152 - val_acc: 0.7610
Epoch 75/1000
 - 676s - loss: 0.4080 - acc: 0.8168 - val_loss: 0.5035 - val_acc: 0.7725
Epoch 76/1000
 - 668s - loss: 0.4055 - acc: 0.8212 - val_loss: 0.4987 - val_acc: 0.7675
Epoch 77/1000
 - 656s - loss: 0.4066 - acc: 0.8183 - val_loss: 0.5023 - val_acc: 0.7599
Epoch 78/1000
 - 667s - loss: 0.4073 - acc: 0.8217 - val_loss: 0.5047 - val_acc: 0.7700
Epoch 79/1000
 - 669s - loss: 0.4034 - acc: 0.8204 - val_loss: 0.4996 - val_acc: 0.7718
Epoch 80/1000
 - 706s - loss: 0.4022 - acc: 0.8219 - val_loss: 0.5104 - val_acc: 0.7559
Epoch 81/1000
 - 665s - loss: 0.3996 - acc: 0.8226 - val_loss: 0.4996 - val_acc: 0.7574
Epoch 82/1000
 - 658s - loss: 0.4025 - acc: 0.8183 - val_loss: 0.5077 - val_acc: 0.7556
Epoch 83/1000
 - 702s - loss: 0.4003 - acc: 0.8245 - val_loss: 0.4997 - val_acc: 0.7642
Epoch 84/1000
 - 683s - loss: 0.3985 - acc: 0.8227 - val_loss: 0.5014 - val_acc: 0.7682
Epoch 85/1000
 - 682s - loss: 0.3989 - acc: 0.8234 - val_loss: 0.5262 - val_acc: 0.7534
Epoch 86/1000

Process finished with exit code 1
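
A log like this is easier to read when the best validation epochs are pulled out programmatically. The sketch below parses Keras verbose=2 output with regular expressions; the function name parse_log is hypothetical, and the embedded LOG is a three-epoch excerpt copied from the output above (in practice the whole log would be passed in).

```python
import re

# A short excerpt of the log above (Keras verbose=2 format).
LOG = """\
Epoch 59/1000
 - 674s - loss: 0.4299 - acc: 0.8073 - val_loss: 0.4999 - val_acc: 0.7772
Epoch 65/1000
 - 674s - loss: 0.4180 - acc: 0.8116 - val_loss: 0.4877 - val_acc: 0.7757
Epoch 85/1000
 - 682s - loss: 0.3989 - acc: 0.8234 - val_loss: 0.5262 - val_acc: 0.7534
"""

EPOCH_RE = re.compile(r"^Epoch (\d+)/\d+")
METRIC_RE = re.compile(r"(\w+): ([0-9.]+)")

def parse_log(text):
    """Return one dict per completed epoch, with the metrics as floats."""
    rows, epoch = [], None
    for line in text.splitlines():
        match = EPOCH_RE.match(line)
        if match:
            epoch = int(match.group(1))
            continue
        metrics = {name: float(value) for name, value in METRIC_RE.findall(line)}
        if epoch is not None and metrics:
            rows.append({"epoch": epoch, **metrics})
            epoch = None  # an "Epoch N" header with no metrics line is skipped
    return rows

rows = parse_log(LOG)
best_loss = min(rows, key=lambda r: r["val_loss"])
best_acc = max(rows, key=lambda r: r["val_acc"])
last = rows[-1]
gap = round(last["acc"] - last["val_acc"], 4)
print(best_loss["epoch"], best_acc["epoch"], gap)
```

Over the full log above, the lowest val_loss is 0.4877 at epoch 65 and the highest val_acc is 0.7772 at epoch 59, while by epoch 85 the training accuracy is 0.07 above the validation accuracy.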

One of the problems appears to be overfitting, but I have already added a lot of dropout to counter some of it.
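
For a sense of scale when thinking about overfitting, the number of trainable parameters in the model above can be worked out by hand. The sketch below uses the standard parameter-count formulas for Conv1D, LSTM, and Dense layers with biases enabled (the Keras default); the function names are hypothetical helpers, and the layer sizes are those in the code above.

```python
def conv1d_params(kernel_size, in_channels, filters):
    # One weight per (kernel position, input channel, filter), plus one bias per filter.
    return kernel_size * in_channels * filters + filters

def lstm_params(input_dim, units):
    # Four gates, each with input weights, recurrent weights, and a bias vector.
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units

steps, channels = 497, 128
conv = conv1d_params(2, channels, 496)   # output shape (497, 496) with padding='same'
pooled_steps = steps // 2                # MaxPooling1D(pool_size=2) leaves 248 steps
lstm = lstm_params(496, 100)             # output shape (100,)
dense = dense_params(100, 2)             # output shape (2,)
total = conv + lstm + dense              # Dropout and pooling layers add no parameters

print(conv, lstm, dense, total)          # 127472 238800 202 366474
print(round(total / 25003, 1))           # parameters per training sample: 14.7
```

That is roughly 366 thousand parameters for 25003 training samples, which may be part of why the training accuracy keeps rising while the validation accuracy stalls.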

Below is an image of the training/validation learning curves.

https://imgur.com/a/ipBIy
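
When the validation curve flattens while the training curve keeps improving, a patience-based stopping rule is a common way to pick a stopping point. The sketch below is a framework-free illustration of that rule, similar in spirit to the Keras EarlyStopping callback with monitor='val_loss'; the function name early_stop_epoch is hypothetical, and the ten values are the val_loss figures for epochs 1 to 10 from the log above.

```python
def early_stop_epoch(val_losses, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    Training stops once `patience` consecutive epochs fail to improve the
    best validation loss by more than `min_delta`. If that never happens,
    stop_epoch is simply the last epoch.
    """
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# val_loss for epochs 1-10, copied from the log above.
first_ten = [0.6859, 0.6435, 0.6120, 0.6148, 0.6056,
             0.5632, 0.5606, 0.5521, 0.5433, 0.5339]

print(early_stop_epoch(first_ten, patience=1))  # (4, 3): stops on the first bad epoch
print(early_stop_epoch(first_ten, patience=3))  # (10, 10): rides out the blip at epoch 4
```

With a noisy validation loss like the one here, a patience that is too small stops far too early, so the value needs to be large enough to absorb the epoch-to-epoch fluctuations.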

0 answers:

No answers yet.