Training loss increases as the amount of data increases, and training accuracy stays almost unchanged

Time: 2019-08-16 17:07:24

Tags: tensorflow machine-learning deep-learning lstm

I am new to ML. I am trying to develop an LSTM model for long time-series data. I have tried several LSTM architectures, but the loss is very large and the accuracy is low. There are many null values in my data, so I removed any sequence (e.g. an 8-minute window) that has at least one null value. I used code from https://machinelearningmastery.com/

import numpy as np
from numpy import array
import pandas as pd
import os
import matplotlib.pyplot as plt
import math
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from keras.optimizers import Adam   

Here I exclude every sequence that contains null values. After excluding a sequence with any null values, it moves on to look for the next valid sequence.

def split_datasetnormalized(dataset, timestamp):

    # 67/33 train/test split; the test rows start where the training rows end
    train_size = int(len(dataset) * 0.67)
    train = dataset[0:train_size, :]
    test = dataset[train_size:len(dataset), :]

    x = []
    y = []
    xp = []
    yp = []
    # Slide a window of `timestamp` inputs plus one target over the training
    # split; any window containing a NaN is skipped entirely.
    for i in range(len(train) - timestamp - 1):
        window = train[i:i + timestamp + 1]
        if len(window) == timestamp + 1 and not np.isnan(window).any():
            x.append(train[i:i + timestamp, 0])
            y.append(train[i + timestamp:i + timestamp + 1, 0])
    # The same windowing and NaN filtering for the test split
    for i in range(len(test) - timestamp - 1):
        window = test[i:i + timestamp + 1]
        if len(window) == timestamp + 1 and not np.isnan(window).any():
            xp.append(test[i:i + timestamp, 0])
            yp.append(test[i + timestamp:i + timestamp + 1, 0])
    return array(x), array(y), array(xp), array(yp)


def lstmwindow(dfData, lags):
    # Reshape to the (samples, timesteps, features) layout Keras LSTMs expect;
    # here each sample is a single timestep carrying `lags` features.
    trainX0, trainY0, testX0, testY0 = split_datasetnormalized(dfData, lags)
    trainX = trainX0.reshape(trainX0.shape[0], 1, trainX0.shape[1])
    trainY = trainY0.reshape(1, -1)[0]
    testX = testX0.reshape(testX0.shape[0], 1, testX0.shape[1])
    testY = testY0.reshape(1, -1)[0]
    return trainX, trainY, testX, testY
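
To sanity-check the windowing, here is a minimal example with synthetic data (the values and the NaN position are made up purely for illustration):

# Synthetic series: 30 readings with one missing value
demo = np.arange(30, dtype='float32').reshape(-1, 1)
demo[10] = np.nan

dx, dy, dpx, dpy = split_datasetnormalized(demo, timestamp=4)
# Every 4-step window whose input or target touches index 10 is skipped
print(dx.shape, dy.shape)   # (10, 4) (10, 1)

tx, ty, vx, vy = lstmwindow(demo, lags=4)
print(tx.shape)             # (10, 1, 4) -> matches input_shape=(1, look_back)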

Although I have a large file (402k), I tried using only the first 20,000 rows. If I reduce the number of rows substantially, the loss decreases and the accuracy improves. The data can be found at https://gofile.io/?c=PoM9dM

trainData = 'data/train.csv'
look_back = 8
df = pd.read_csv(trainData, usecols=['tested'], nrows=20000)
dataset = df.values
dataset = dataset.astype('float32')

# Scale the series to [0, 1] before windowing; the same scaler is used later
# to map predictions back to the original units
scaler = MinMaxScaler(feature_range=(0, 1))
dataset = scaler.fit_transform(dataset)
trainX, trainY, testX, testY = lstmwindow(dataset, look_back)

I used a custom learning rate because it gave better results than the default settings.

opt = Adam(lr=0.0000001, decay=.2)
model = Sequential()
model.add(LSTM(1028, input_shape=(1, look_back), return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64))
model.add(Dense(1))
# Pass the custom optimizer object; the string 'adam' would silently fall back
# to the default learning rate and ignore `opt`
model.compile(loss='mse', optimizer=opt, metrics=['mape', 'acc'])
model.fit(trainX, trainY, epochs=20, batch_size=64, verbose=1)
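
For reference, with the legacy Keras optimizers the decay argument shrinks the learning rate every batch as lr / (1 + decay * iterations), so lr=0.0000001 combined with decay=0.2 becomes vanishingly small almost immediately. A quick sketch of the effective rate (the step counts are arbitrary examples):

# Effective learning rate under the legacy Keras decay schedule:
# lr_t = lr / (1 + decay * iterations)
lr, decay = 1e-7, 0.2
for step in (0, 100, 1000):
    print(step, lr / (1 + decay * step))
# step 1000 -> ~5e-10, far too small for the weights to move meaningfully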


trainPredict = model.predict(trainX)
testPredict = model.predict(testX)

# Undo the MinMax scaling so predictions and targets are in the original units
trainPredict = scaler.inverse_transform(trainPredict)
trainY = scaler.inverse_transform([trainY])
testPredict = scaler.inverse_transform(testPredict)
testY = scaler.inverse_transform([testY])
print(trainPredict, trainY)


trainScore = math.sqrt(mean_squared_error(trainY[0], trainPredict[:, 0]))
print('Train Score: %.2f RMSE' % (trainScore))
testScore = math.sqrt(mean_squared_error(testY[0], testPredict[:, 0]))
print('Test Score: %.2f RMSE' % (testScore))


plt.plot(trainY[0][:100])
plt.plot(trainPredict.reshape(-1, trainPredict.shape[0])[0][:100])
plt.show()

The result is shown below. Sometimes, after one epoch, the loss gradually increases while the accuracy stays very low.

 2432/12552 [====>.........................] - ETA: 32s - loss: 0.0059 - mean_absolute_percentage_error: 992116.1136 - acc: 0.0082
 2496/12552 [====>.........................] - ETA: 32s - loss: 0.0060 - mean_absolute_percentage_error: 966677.6248 - acc: 0.0080
 2560/12552 [=====>........................] - ETA: 32s - loss: 0.0061 - mean_absolute_percentage_error: 963082.6779 - acc: 0.0086
 2624/12552 [=====>........................] - ETA: 31s - loss: 0.0061 - mean_absolute_percentage_error: 939593.1212 - acc: 0.0084
 2688/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 957549.8326 - acc: 0.0089
 2752/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 935281.5181 - acc: 0.0087
 2816/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 963213.3245 - acc: 0.0092
 2880/12552 [=====>........................] - ETA: 31s - loss: 0.0059 - mean_absolute_percentage_error: 941808.8309 - acc: 0.0090
 2944/12552 [======>.......................] - ETA: 31s - loss: 0.0059 - mean_absolute_percentage_error: 921335.0405 - acc: 0.0088
 3008/12552 [======>.......................] - ETA: 30s - loss: 0.0059 - mean_absolute_percentage_error: 920601.9731 - acc: 0.0090
 3072/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 901423.1130 - acc: 0.0088
 3136/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 901941.2332 - acc: 0.0089
 3200/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 883902.6478 - acc: 0.0088
 3264/12552 [======>.......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 887954.1915 - acc: 0.0089
 3328/12552 [======>.......................] - ETA: 29s - loss: 0.0059 - mean_absolute_percentage_error: 889670.6806 - acc: 0.0090
 3392/12552 [=======>......................] - ETA: 29s - loss: 0.0059 - mean_absolute_percentage_error: 891472.6347 - acc: 0.0091
 3456/12552 [=======>......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 907832.6322 - acc: 0.0093
 3520/12552 [=======>......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 891326.8646 - acc: 0.0091
 3584/12552 [=======>......................] - ETA: 28s - loss: 0.0061 - mean_absolute_percentage_error: 1068598.5278 - acc: 0.0098
 3648/12552 [=======>......................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1089488.1545 - acc: 0.0101
 3712/12552 [=======>......................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1070704.1027 - acc: 0.0100
 3776/12552 [========>.....................] - ETA: 28s - loss: 0.0061 - mean_absolute_percentage_error: 1052556.8863 - acc: 0.0098
 3840/12552 [========>.....................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1035014.4576 - acc: 0.0096
 3904/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1018047.2437 - acc: 0.0095
 3968/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1044726.5623 - acc: 0.0096
 4032/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1047885.8193 - acc: 0.0097
 4096/12552 [========>.....................] - ETA: 27s - loss: 0.0061 - mean_absolute_percentage_error: 1198321.3221 - acc: 0.0095
 4160/12552 [========>.....................] - ETA: 27s - loss: 0.0061 - mean_absolute_percentage_error: 1265580.6710 - acc: 0.0099
 4224/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1302111.3199 - acc: 0.0102
 4288/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1282676.9945 - acc: 0.0100
 4352/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1263814.2721 - acc: 0.0099
 4416/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1254763.8207 - acc: 0.0100
 4480/12552 [=========>....................] - ETA: 26s - loss: 0.0060 - mean_absolute_percentage_error: 1236838.8069 - acc: 0.0098
 4544/12552 [=========>....................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1244310.3624 - acc: 0.0099
 4608/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1227028.4810 - acc: 0.0098
 4672/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1210220.0377 - acc: 0.0096
 4736/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1230652.4426 - acc: 0.0097
 4800/12552 [==========>...................] - ETA: 25s - loss: 0.0059 - mean_absolute_percentage_error: 1214243.8915 - acc: 0.0096
 4864/12552 [==========>...................] - ETA: 24s - loss: 0.0059 - mean_absolute_percentage_error: 1198267.1610 - acc: 0.0095

1 Answer:

Answer 0: (score: 0)

You are probably overfitting the model here, which would explain the drop in accuracy and the increase in loss. Try parameter tuning. Your learning rate could be higher. Try adding dropout to avoid overfitting.
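
For example, a rough sketch of those suggestions (the dropout rate of 0.2 and lr=0.001 are assumed starting points, not tuned values; the accuracy metric is dropped because accuracy is not meaningful for a regression target, and a validation set is passed so overfitting shows up directly):

from keras.layers import Dropout

opt = Adam(lr=0.001)  # a much higher, more conventional starting learning rate
model = Sequential()
model.add(LSTM(128, input_shape=(1, look_back), return_sequences=True))
model.add(Dropout(0.2))   # randomly zeroes 20% of activations to curb overfitting
model.add(LSTM(64))
model.add(Dropout(0.2))
model.add(Dense(1))
model.compile(loss='mse', optimizer=opt, metrics=['mape'])
model.fit(trainX, trainY, validation_data=(testX, testY),
          epochs=20, batch_size=64, verbose=1)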
