Accuracy stuck during training with binary-matrix input in Keras

Date: 2017-10-15 20:09:29

Tags: python neural-network keras

I have a set of about 1500 features and 300 classes. Each input object is described by some subset of those 1500 features and belongs to exactly one of the 300 classes. Input objects are "vectorized" (multi-hot encoded) according to which features are present: for example, an object with 25 features is represented as a vector of 25 ones and 1475 zeros. Class labels are one-hot encoded.

The dataset has roughly 10k rows, so X_train has shape (10000, 1500) and Y_train has shape (10000, 325), which makes the matrices very sparse. I am trying to build a multi-class classifier in Keras. Here is my model:

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras import metrics

model = Sequential()
model.add(Dense(3000, input_dim=1500, activation='relu', kernel_initializer='he_uniform'))
model.add(Dropout(0.4))
model.add(Dense(2500, activation='relu', kernel_initializer='he_uniform'))
model.add(Dropout(0.4))
model.add(Dense(2000, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(1500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(1000, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(400, activation='relu'))
model.add(Dense(300, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy', metrics.top_k_categorical_accuracy])
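For reference, the multi-hot/one-hot encoding described above can be sketched as follows. This is a minimal illustration, not the exact preprocessing used in the question; the helper names `vectorize` and `one_hot` are hypothetical.

```python
import numpy as np

NUM_FEATURES = 1500
NUM_CLASSES = 300

def vectorize(feature_indices, num_features=NUM_FEATURES):
    """Multi-hot encode an object: 1 at each present feature, 0 elsewhere."""
    x = np.zeros(num_features, dtype=np.float32)
    x[feature_indices] = 1.0
    return x

def one_hot(label, num_classes=NUM_CLASSES):
    """One-hot encode a single class label."""
    y = np.zeros(num_classes, dtype=np.float32)
    y[label] = 1.0
    return y

# An object with 25 features becomes a vector with 25 ones and 1475 zeros.
obj = vectorize(list(range(25)))
print(obj.shape, obj.sum())  # (1500,) 25.0
```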

The problem is that even after 100-200-300 epochs, accuracy never rises above about 0.4. I have tried changing the number of hidden layers and neurons, and turning dropout and all other regularization on and off, but the result is always the same.

Epoch 4/300
9450/9450 [==============================] - 27s - loss: 4.0560 - acc: 0.2440 - top_k_categorical_accuracy: 0.3423 - val_loss: 3.6102 - val_acc: 0.3235 - val_top_k_categorical_accuracy: 0.3987
Epoch 5/300
9450/9450 [==============================] - 27s - loss: 3.5148 - acc: 0.3233 - top_k_categorical_accuracy: 0.4218 - val_loss: 3.2821 - val_acc: 0.4015 - val_top_k_categorical_accuracy: 0.4377
Epoch 6/300
9450/9450 [==============================] - 27s - loss: 3.3581 - acc: 0.3585 - top_k_categorical_accuracy: 0.4401 - val_loss: 3.2520 - val_acc: 0.4053 - val_top_k_categorical_accuracy: 0.4339
Epoch 7/300
9450/9450 [==============================] - 27s - loss: 3.2331 - acc: 0.3898 - top_k_categorical_accuracy: 0.4460 - val_loss: 3.1321 - val_acc: 0.4196 - val_top_k_categorical_accuracy: 0.4358
Epoch 8/300
9450/9450 [==============================] - 27s - loss: 3.1762 - acc: 0.4003 - top_k_categorical_accuracy: 0.4477 - val_loss: 3.1258 - val_acc: 0.4215 - val_top_k_categorical_accuracy: 0.4377
Epoch 9/300
9450/9450 [==============================] - 27s - loss: 3.1389 - acc: 0.4086 - top_k_categorical_accuracy: 0.4513 - val_loss: 3.0792 - val_acc: 0.4272 - val_top_k_categorical_accuracy: 0.4405
Epoch 10/300
9450/9450 [==============================] - 27s - loss: 3.0873 - acc: 0.4224 - top_k_categorical_accuracy: 0.4505 - val_loss: 3.0815 - val_acc: 0.4263 - val_top_k_categorical_accuracy: 0.4396
Epoch 11/300
9450/9450 [==============================] - 27s - loss: 3.0869 - acc: 0.4242 - top_k_categorical_accuracy: 0.4539 - val_loss: 3.1044 - val_acc: 0.4272 - val_top_k_categorical_accuracy: 0.4377
Epoch 12/300
9450/9450 [==============================] - 28s - loss: 3.0760 - acc: 0.4260 - top_k_categorical_accuracy: 0.4528 - val_loss: 3.0833 - val_acc: 0.4263 - val_top_k_categorical_accuracy: 0.4405
Epoch 13/300
9450/9450 [==============================] - 27s - loss: 3.0815 - acc: 0.4236 - top_k_categorical_accuracy: 0.4528 - val_loss: 3.0685 - val_acc: 0.4282 - val_top_k_categorical_accuracy: 0.4396
Epoch 14/300
9450/9450 [==============================] - 27s - loss: 3.0963 - acc: 0.4207 - top_k_categorical_accuracy: 0.4517 - val_loss: 3.1324 - val_acc: 0.4196 - val_top_k_categorical_accuracy: 0.4377
Epoch 15/300
9450/9450 [==============================] - 27s - loss: 3.0927 - acc: 0.4197 - top_k_categorical_accuracy: 0.4530 - val_loss: 3.0870 - val_acc: 0.4272 - val_top_k_categorical_accuracy: 0.4396
Epoch 16/300
9450/9450 [==============================] - 27s - loss: 3.0795 - acc: 0.4257 - top_k_categorical_accuracy: 0.4530 - val_loss: 3.0834 - val_acc: 0.4301 - val_top_k_categorical_accuracy: 0.4415
Epoch 17/300
9450/9450 [==============================] - 27s - loss: 3.0657 - acc: 0.4266 - top_k_categorical_accuracy: 0.4533 - val_loss: 3.1244 - val_acc: 0.4253 - val_top_k_categorical_accuracy: 0.4405
Epoch 18/300
9450/9450 [==============================] - 27s - loss: 3.0784 - acc: 0.4262 - top_k_categorical_accuracy: 0.4551 - val_loss: 3.0882 - val_acc: 0.4272 - val_top_k_categorical_accuracy: 0.4443
Epoch 19/300
9450/9450 [==============================] - 27s - loss: 3.0540 - acc: 0.4303 - top_k_categorical_accuracy: 0.4525 - val_loss: 3.0817 - val_acc: 0.4282 - val_top_k_categorical_accuracy: 0.4405
Epoch 20/300
9450/9450 [==============================] - 27s - loss: 3.0815 - acc: 0.4255 - top_k_categorical_accuracy: 0.4516 - val_loss: 3.0910 - val_acc: 0.4244 - val_top_k_categorical_accuracy: 0.4405

Any ideas on how to improve the accuracy, or on whether something is wrong with the input data?

0 Answers:

No answers yet.