I have sequences of signal data with 2 channels and 6 different labels, and I want to predict the label from the signal sequence as input. From what I've read, an LSTM is the way to go when classifying sequential data. I tried to build a network, but my accuracy always stays below 20%. Also, when I print out the predictions (my last layer is a softmax), they are all very close to one another, which I think means the model maps every input to essentially the same prediction.
My data has shape (312, 400, 2): 312 samples, 400 timesteps and 2 features. The labels are one-hot encoded with pandas' get_dummies().
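For reference, here is a minimal sketch of how the arrays are prepared; the random placeholder data only reproduces the shapes (it is not my real data), while get_dummies() is the actual call I use for the labels:

import numpy as np
import pandas as pd

# Placeholder data with the same shapes as my real set:
# 312 samples, 400 timesteps, 2 features, 6 possible labels.
X = np.random.randn(312, 400, 2).astype('float32')
labels = np.random.randint(0, 6, size=312)

# One-hot encode the labels with pandas.
y = pd.get_dummies(labels).values.astype('float32')   # shape (312, 6)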
The layers look like this:
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Single LSTM layer over (400 timesteps, 2 features), then a small classifier head.
model = Sequential()
model.add(LSTM(100, input_shape=(400, 2), dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(64, activation='relu'))
model.add(Dense(6, activation='softmax'))
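For completeness, this is roughly how the model gets compiled and trained on the arrays from the sketch above; the loss and optimizer follow from the softmax output and the Adam I mention further down, but batch_size and validation_split are assumptions on my part (the split is only inferred from the 209 training samples shown in the log):

# Compile/fit sketch; batch_size and validation_split are illustrative values.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.fit(X, y, epochs=25, batch_size=32, validation_split=0.33)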
The training epochs look like this:
Epoch 1/25
209/209 [==============================] - 13s 63ms/step - loss: 1.7931 - acc: 0.1531 - val_loss: 1.7916 - val_acc: 0.1456
Epoch 2/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7923 - acc: 0.1675 - val_loss: 1.7918 - val_acc: 0.1748
Epoch 3/25
209/209 [==============================] - 8s 38ms/step - loss: 1.7929 - acc: 0.1675 - val_loss: 1.7919 - val_acc: 0.1650
Epoch 4/25
209/209 [==============================] - 8s 38ms/step - loss: 1.7924 - acc: 0.1100 - val_loss: 1.7920 - val_acc: 0.1553
Epoch 5/25
209/209 [==============================] - 8s 38ms/step - loss: 1.7917 - acc: 0.1962 - val_loss: 1.7921 - val_acc: 0.1650
Epoch 6/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7926 - acc: 0.1292 - val_loss: 1.7921 - val_acc: 0.1650
Epoch 7/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7914 - acc: 0.1818 - val_loss: 1.7922 - val_acc: 0.1553
Epoch 8/25
209/209 [==============================] - 9s 43ms/step - loss: 1.7922 - acc: 0.2010 - val_loss: 1.7922 - val_acc: 0.1553
Epoch 9/25
209/209 [==============================] - 9s 42ms/step - loss: 1.7916 - acc: 0.1675 - val_loss: 1.7922 - val_acc: 0.1553
Epoch 10/25
209/209 [==============================] - 8s 41ms/step - loss: 1.7918 - acc: 0.1866 - val_loss: 1.7923 - val_acc: 0.1553
Epoch 11/25
209/209 [==============================] - 8s 41ms/step - loss: 1.7908 - acc: 0.2010 - val_loss: 1.7923 - val_acc: 0.1456
Epoch 12/25
209/209 [==============================] - 8s 40ms/step - loss: 1.7918 - acc: 0.2010 - val_loss: 1.7923 - val_acc: 0.1553
Epoch 13/25
209/209 [==============================] - 8s 40ms/step - loss: 1.7906 - acc: 0.1627 - val_loss: 1.7923 - val_acc: 0.1553
Epoch 14/25
209/209 [==============================] - 8s 40ms/step - loss: 1.7909 - acc: 0.1579 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 15/25
209/209 [==============================] - 9s 42ms/step - loss: 1.7914 - acc: 0.1675 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 16/25
209/209 [==============================] - 9s 43ms/step - loss: 1.7912 - acc: 0.2201 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 17/25
209/209 [==============================] - 8s 40ms/step - loss: 1.7916 - acc: 0.1675 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 18/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7899 - acc: 0.1722 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 19/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7908 - acc: 0.2153 - val_loss: 1.7924 - val_acc: 0.1456
Epoch 20/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7905 - acc: 0.1866 - val_loss: 1.7924 - val_acc: 0.1650
Epoch 21/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7901 - acc: 0.1770 - val_loss: 1.7924 - val_acc: 0.1650
Epoch 22/25
209/209 [==============================] - 8s 38ms/step - loss: 1.7913 - acc: 0.1722 - val_loss: 1.7924 - val_acc: 0.1650
Epoch 23/25
209/209 [==============================] - 8s 39ms/step - loss: 1.7915 - acc: 0.1818 - val_loss: 1.7924 - val_acc: 0.1650
Epoch 24/25
209/209 [==============================] - 9s 41ms/step - loss: 1.7910 - acc: 0.2105 - val_loss: 1.7924 - val_acc: 0.1553
Epoch 25/25
209/209 [==============================] - 9s 43ms/step - loss: 1.7904 - acc: 0.2105 - val_loss: 1.7925 - val_acc: 0.1553
When I print out the predictions, they look like this:
array([[0.16541722, 0.1673543 , 0.16738486, 0.16682732, 0.16618429,
        0.16683201],
       [0.16414133, 0.16915058, 0.16742292, 0.16625686, 0.16690722,
        0.16612107],
       [0.16567224, 0.1668862 , 0.16726284, 0.1661307 , 0.16756196,
        0.16648607],
       ...,
       [0.165552  , 0.16795571, 0.16799515, 0.16348934, 0.16906977,
        0.16593806],
       [0.16314913, 0.16983458, 0.16802336, 0.16656826, 0.16621879,
        0.1662058 ],
       [0.16357513, 0.16757166, 0.16752186, 0.16805767, 0.16549116,
        0.1677825 ]], dtype=float32)
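For what it's worth, those rows are essentially the uniform distribution over the 6 classes, and the loss being stuck around 1.79 is consistent with that. A quick sanity check:

import numpy as np

# A model that always predicts the uniform distribution over 6 classes
# assigns probability 1/6 to each class and has cross-entropy ln(6).
print(1 / 6)       # 0.1666... -- close to every predicted probability above
print(np.log(6))   # 1.7917... -- close to the plateaued training/validation loss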
I am currently using the Adam optimizer, but I have also tried SGD with gradient clipping and many other optimizers. I have also tried increasing and decreasing the learning rate anywhere between 0.1 and 0.000001.
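For concreteness, these are the kinds of optimizer configurations I have been swapping in; the specific learning-rate and clipnorm values are just examples from the ranges I mentioned, not an exhaustive list:

from keras.optimizers import Adam, SGD

# Illustrative settings only; lr= is the older Keras argument name
# (newer tf.keras uses learning_rate=).
opt = Adam(lr=0.001)
# opt = SGD(lr=0.01, clipnorm=1.0)   # SGD with gradient clipping
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])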
All features are standard scaled (a sketch of how I do this is at the end of the post). I am having a hard time understanding what is wrong with the network. Is it the data or the data shape? Is it the architecture of the network? I am out of ideas. Thanks for any help.
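For completeness, the standard scaling is applied roughly like this; flattening the time axis so the scaler sees one row per timestep is just my approach, and the variable names match the data sketch at the top:

from sklearn.preprocessing import StandardScaler

# Scale each of the 2 features to zero mean / unit variance:
# flatten (samples, timesteps, features) to (samples * timesteps, features),
# fit the scaler, then restore the original 3D shape.
scaler = StandardScaler()
X_flat = X.reshape(-1, X.shape[-1])                    # (312 * 400, 2)
X = scaler.fit_transform(X_flat).reshape(312, 400, 2)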