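I have been trying to build an LSTM RNN with TensorFlow Keras to predict whether someone is driving or not (binary classification), based only on datetime and lat/long. However, when I train the network, the loss and val_loss barely change.
I converted lat/long to x, y, z coordinates between -1 and 1. I also used the datetime to extract whether it is a weekday or a weekend, and the period of the day (morning/afternoon/evening).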
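For readers who want to reproduce the preprocessing, here is a minimal sketch of one common way to derive such features. The function names, the weekday flag encoding, and the hour boundaries for the period of day are assumptions for illustration; the exact conventions used in the post are not shown. The sample rows do satisfy x² + y² + z² ≈ 1, which is consistent with a unit-sphere mapping like this one.

```python
import math
from datetime import datetime


def latlon_to_xyz(lat_deg, lon_deg):
    """Map latitude/longitude in degrees to a point on the unit sphere.

    Each coordinate falls in [-1, 1].
    """
    lat = math.radians(lat_deg)
    lon = math.radians(lon_deg)
    x = math.cos(lat) * math.cos(lon)
    y = math.cos(lat) * math.sin(lon)
    z = math.sin(lat)
    return x, y, z


def time_features(ts):
    """Return (weekday_flag, period_of_day) for a datetime.

    Assumed encoding (hypothetical, may differ from the post):
      weekday_flag: 1 for Monday-Friday, 0 for Saturday/Sunday
      period_of_day: 0 before 12:00, 1 from 12:00 to 17:59, 2 from 18:00
    """
    weekday_flag = 1 if ts.weekday() < 5 else 0  # Monday == 0, Sunday == 6
    if ts.hour < 12:
        period = 0
    elif ts.hour < 18:
        period = 1
    else:
        period = 2
    return weekday_flag, period


if __name__ == "__main__":
    print(latlon_to_xyz(0.0, 0.0))
    print(time_features(datetime(2011, 8, 27, 6, 13, 1)))
```

Encoding the position on the unit sphere avoids the discontinuity at longitude ±180° that raw lat/long values would have.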
Below is a sample of the data (the formatting is a bit odd):
                         trip_id  weekday  period_of_day         x         y         z  mode_cat
datetime            id
2011-08-27 06:13:01 20         1        0              2  0.650429  0.043524  0.758319         1
2011-08-27 06:13:02 20         1        0              2  0.650418  0.043487  0.758330         1
2011-08-27 06:13:03 20         1        0              2  0.650421  0.043490  0.758328         1
2011-08-27 06:13:04 20         1        0              2  0.650427  0.043506  0.758322         1
2011-08-27 06:13:05 20         1        0              2  0.650438  0.043516  0.758312         1
Here is the code that builds the network:
import tensorflow as tf

single_step_model = tf.keras.models.Sequential()
single_step_model.add(tf.keras.layers.LSTM(512, return_sequences=True,
                                           input_shape=x_train_single.shape[-2:]))
single_step_model.add(tf.keras.layers.Dropout(0.4))
single_step_model.add(tf.keras.layers.Dense(128, activation='tanh'))
single_step_model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
single_step_model.compile(optimizer=opt, loss='binary_crossentropy',
                          metrics=['accuracy'])
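The `input_shape=x_train_single.shape[-2:]` argument implies that `x_train_single` is a 3-D array of shape (samples, timesteps, features). The post does not show how it is built, so the following is only a sketch of one plausible windowing scheme, using synthetic data; `make_windows` and `history_size` are hypothetical names, and labeling each window with the row that follows it is an assumption.

```python
import numpy as np


def make_windows(features, labels, history_size):
    """Slice a (rows, n_features) array into overlapping windows.

    Returns X with shape (samples, history_size, n_features) and y with
    one label per window (the label of the row right after the window).
    """
    X, y = [], []
    for end in range(history_size, len(features)):
        X.append(features[end - history_size:end])
        y.append(labels[end])
    return np.array(X), np.array(y)


if __name__ == "__main__":
    # Synthetic stand-in data: 10 rows with 5 features each.
    rng = np.random.default_rng(0)
    features = rng.random((10, 5))
    labels = rng.integers(0, 2, size=10)

    X, y = make_windows(features, labels, history_size=4)
    print(X.shape, y.shape)  # (6, 4, 5) (6,)
    print(X.shape[-2:])      # (4, 5), the value passed as input_shape
```

With `return_sequences=True`, the LSTM emits one output per time step, so the model above produces predictions of shape (batch, timesteps, 1). Whether the labels are per window, as in this sketch, or per time step is therefore worth checking against that output shape.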
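I have tried various learning rates, batch sizes, numbers of epochs, dropout rates, numbers of hidden layers, and numbers of units, and all of them run into this problem.
I also looked at my data and noticed that the accuracy and val_accuracy match the class proportion in the training/validation data, i.e. the number of driving rows / total number of rows. This means my network is always predicting the same result.
Below are the training and validation loss for each epoch: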
Epoch 1/100
1410/1410 [==============================] - 775s 550ms/step - loss: 0.6942 - binary_accuracy: 0.5273 - val_loss: 0.6909 - val_binary_accuracy: 0.5380
Epoch 2/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6911 - binary_accuracy: 0.5352 - val_loss: 0.6904 - val_binary_accuracy: 0.5380
Epoch 3/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6906 - binary_accuracy: 0.5374 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 4/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6905 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 5/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 6/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6904 - val_binary_accuracy: 0.5380
Epoch 7/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 8/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 9/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 10/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 11/100
1410/1410 [==============================] - 775s 550ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 12/100
1410/1410 [==============================] - 775s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 13/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5377 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 14/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6904 - binary_accuracy: 0.5374 - val_loss: 0.6903 - val_binary_accuracy: 0.5379
Epoch 15/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6903 - binary_accuracy: 0.5377 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 16/100
1410/1410 [==============================] - 774s 549ms/step - loss: 0.6904 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 17/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 18/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5375 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
Epoch 19/100
1410/1410 [==============================] - 773s 548ms/step - loss: 0.6903 - binary_accuracy: 0.5376 - val_loss: 0.6903 - val_binary_accuracy: 0.5380
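As a sanity check on the log above, the plateau can be compared with the loss of a model that ignores its inputs and always outputs the class prior. This is a minimal sketch, assuming the reported val_binary_accuracy of 0.5380 is the share of the majority class; the binary cross-entropy of such a constant predictor is the entropy of the label distribution, which is symmetric, so it does not matter which class is the majority.

```python
import math


def constant_prediction_loss(p_positive):
    """Binary cross-entropy of a model that always outputs p_positive,
    evaluated on data where a fraction p_positive of the labels are 1."""
    p = p_positive
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


if __name__ == "__main__":
    # 0.5380 is taken from val_binary_accuracy in the log (assumed to be
    # the majority-class share of the validation set).
    loss = constant_prediction_loss(0.5380)
    print(f"{loss:.3f}")  # approximately 0.690, close to the logged val_loss
```

A loss that settles at this value is what a network that has learned only the class balance would produce, which supports the observation that it always predicts the same result.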
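Is this because there is not enough information in my features/dataset for my network to learn from? Or is there something wrong with the network itself? What else can I try? Please advise.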