递归神经网络二进制分类

时间:2019-01-21 13:02:22

标签: python tensorflow machine-learning keras recurrent-neural-network

我可以访问100个人的数据框,以及他们在特定运动测试中的表现。此帧每人包含约25,000行,因为此人的表现被跟踪(大约)每厘米(10 ^ -2)。我们希望使用这些数据来预测二进制y标签,也就是说,是否有人运动障碍。

数据集的列和某些值如下:

'Person_ID', 'time_in_game', 'python_time', 'permutation_game, 'round', 'level', 'times_level_played_before', 'speed', 'costheta', 'y_label', 'gender', 'age_precise', 'ax_f', 'ay_f', 'az_f', 'acc', 'jerk'
1,            0.25,           1.497942e+09,  2,                 1,      'level_B', 1,                           0.8,    0.4655,    1,         [...]

通过每半秒使用一行,我将数据集减少到每人只有480行。

现在,我想使用递归神经网络来预测二进制y_label。

此代码提取用于输入数据X的costheta特征和用于输出Y的y标签。

X = []
Y = []

for ID in person_list:
    person_frame = df.loc[df['Person_ID'] == Person_ID]

    # costheta is a measurement of performance
    coslist = list(person_frame['costheta'])

    # extract y-label
    score = list(person_frame['y_label'].head(1))[0]

    X.append(coslist)
    Y.append(binary)

我使用0.2测试分割将数据分割为训练和测试数据。然后,我尝试使用Keras创建RNN,如下所示:

from keras import Sequential
from keras.layers import Embedding, LSTM, Dense, Dropout

embedding_size=32
model=Sequential()

# different_input_values are the set of possible input values
model.add(Embedding(different_input_values, embedding_size, input_length=480))
model.add(LSTM(1000))

# output is binary
model.add(Dense(1, activation='sigmoid'))
print(model.summary())

最后,我开始使用以下代码进行训练:

model.compile(loss='binary_crossentropy', 
             optimizer='adam', 
             metrics=['accuracy'])

batch_size = 64
num_epochs = 100

X_valid, y_valid = X_train[:batch_size], Y_train[:batch_size]
X_train2, y_train2 = X_train[batch_size:], Y_train[batch_size:]

model.fit(X_train2, y_train2, validation_data=(X_valid, y_valid), batch_size=batch_size, epochs=num_epochs).

但是,获得的准确度确实很低。根据批次大小,它在0.4到0.6之间变化。

  

12/12 [=============================]-13秒1秒/步-损失:0.6921-   acc:0.7500-val_loss:0.7069-val_acc:0.4219

我的问题是,通常来说,使用这样的复杂数据,如何有效地训练RNN。是否应该避免将数据减少到每人480行并将其保持在每人25,000行左右?诸如acc(游戏中的加速度)和jerk之类的多个指标会带来明显的准确性提升吗?人们可以改变和考虑哪些重大改进?

0 个答案:

没有答案