LSTM培训损失和Val损失没有改变

时间:2020-06-02 19:21:06

标签: python tensorflow keras deep-learning lstm

我一直在尝试使用张量流keras创建LSTM RNN,以便仅基于Datetime和lat / long来预测某人是否在开车(二进制分类)。但是,当我训练网络时,损耗和val_loss并没有太大变化。

我将经/纬度转换为介于-1和1之间的x,y,z坐标。我还使用Datetime提取了是否是周末以及一天中的哪个时段(上午/下午/晚上)

以下是数据示例(格式有点奇怪):

                 trip_id weekday period_of_day  x     y        z        mode_cat
datetime    id                          
2011-08-27 06:13:01 20  1   0   2         0.650429  0.043524    0.758319    1
2011-08-27 06:13:02 20  1   0   2         0.650418  0.043487    0.758330    1
2011-08-27 06:13:03 20  1   0   2         0.650421  0.043490    0.758328    1
2011-08-27 06:13:04 20  1   0   2         0.650427  0.043506    0.758322    1
2011-08-27 06:13:05 20  1   0   2         0.650438  0.043516    0.758312    1

这是构建网络的代码:

single_step_model = tf.keras.models.Sequential()
single_step_model.add(tf.keras.layers.LSTM(512, return_sequences=True,
                                           input_shape=x_train_single.shape[-2:]))
single_step_model.add(tf.keras.layers.Dropout(0.4))
single_step_model.add(tf.keras.layers.Dense(128, activation='tanh'))
single_step_model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
single_step_model.compile(optimizer=opt, loss='binary_crossentropy',
                          metrics=['accuracy'])

我尝试了各种不同的学习率,批处理大小,时代,辍学,隐藏层数,单元数,所有这些都遇到了这个问题。

我也查看了我的数据,发现损失和val_loss等于训练/验证数据的百分比,即该数据集的行驶次数/总行数。这意味着我的网络总是在预测相同的结果。

以下是每个时期的训练和验证损失数据:

Epoch 1/100
1410/1410 [==============================] - 775s 550ms/step - loss: 0.6942 - binary_accuracy: 0.5273 - val_loss: 0.6909 - val_binary_accuracy: 0.5380
Epoch 2/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6911 - binary_accuracy: 0.5352 - val_loss: 0.6904 - val_binary_accuracy: 0.5380
Epoch 3/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6906 - binary_accuracy: 0.5374 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 4/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6905 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 5/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 6/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6904 - val_binary_accuracy: 0.5380
Epoch 7/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 8/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 9/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 10/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 11/100
1410/1410 [==============================] - 775s 550ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 12/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 13/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5377 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 14/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6904 - binary_accuracy: 0.5374 - val_loss: 0.6903 - val_binary_accuracy: 0.5379
Epoch 15/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5377 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 16/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 17/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 18/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 19/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380

这是因为我的功能/数据集中没有足够的信息可供我的网络学习吗?还是网络本身有问题?我还能尝试什么?请告知。

0 个答案:

没有答案