I am new to ML. I am trying to develop an LSTM model for a long time-series dataset. I have tried several LSTM architectures, but the loss is very high and the accuracy is low. My data contains many null values, so I drop any sequence (e.g. an 8-minute window) that contains at least one null. I adapted the code from https://machinelearningmastery.com/
import numpy as np
from numpy import array
import math
import pandas as pd
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM
from keras.optimizers import Adam
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
Here I exclude every sequence that contains a null value. After skipping a sequence with nulls, the loop moves on to the next one:
def split_datasetnormalized(dataset, timestamp):
    # 67/33 chronological split; note the test slice must start at
    # train_size (not test_size), or train and test would overlap.
    train_size = int(len(dataset) * 0.67)
    train, test = dataset[0:train_size, :], dataset[train_size:len(dataset), :]
    x, y, xp, yp = [], [], [], []
    for i in range(len(train) - timestamp - 1):
        # keep a window only if the inputs and the target contain no NaN
        if not np.isnan(train[i:i + timestamp + 1]).any():
            x.append(train[i:i + timestamp, 0])
            y.append(train[i + timestamp:i + timestamp + 1, 0])
    for i in range(len(test) - timestamp - 1):
        if not np.isnan(test[i:i + timestamp + 1]).any():
            xp.append(test[i:i + timestamp, 0])
            yp.append(test[i + timestamp:i + timestamp + 1, 0])
    return array(x), array(y), array(xp), array(yp)
def lstmwindow(dfData, lags):
database = split_datasetnormalized(dfData, lags)
trainX = database[0].reshape(database[0].shape[0], 1, database[0].shape[1])
trainY = (database[1].reshape(1, -1))[0]
testX = database[2].reshape(database[2].shape[0], 1, database[2].shape[1])
testY = (database[3].reshape(1, -1))[0]
return trainX, trainY, testX, testY
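To sanity-check the windowing step above, here is a minimal, self-contained sketch of the same NaN-skipping logic on a 1-D toy series (the series values and window length are made up for illustration; `make_windows` is a simplified stand-in for `split_datasetnormalized`):

```python
import numpy as np

def make_windows(series, timestamp):
    # Slide a window of length `timestamp` over the series; keep only
    # windows whose inputs *and* target contain no NaN, the same rule
    # split_datasetnormalized applies per train/test half.
    x, y = [], []
    for i in range(len(series) - timestamp - 1):
        chunk = series[i:i + timestamp + 1]
        if not np.isnan(chunk).any():
            x.append(chunk[:timestamp])   # inputs
            y.append(chunk[timestamp])    # one-step-ahead target
    return np.array(x), np.array(y)

series = np.array([1., 2., 3., np.nan, 5., 6., 7., 8., 9.])
x, y = make_windows(series, 3)
print(x.shape, y.shape)  # (1, 3) (1,)
```

With a single NaN at index 3, every window touching it is discarded, so only `[5, 6, 7] -> 8` survives; this is also why dropping whole windows can shrink the dataset much faster than the raw null count suggests.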
Although I have a large file (402k rows), I tried using only the first 20,000 rows. If I reduce the number of rows substantially, the loss decreases and the accuracy improves. The data can be found at https://gofile.io/?c=PoM9dM
trainData = 'data/train.csv'
look_back = 8
df = pd.read_csv(trainData, usecols=['tested'], nrows=20000)
dataset = df.values
dataset = dataset.astype('float32')
scaler = MinMaxScaler(feature_range=(0, 1))
dataset = scaler.fit_transform(dataset)
trainX, trainY, testX, testY = lstmwindow(dataset, look_back)
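Since the predictions are later mapped back to original units with `scaler.inverse_transform`, it may help to see the round trip on a tiny, made-up column (`MinMaxScaler` expects 2-D input, which is why the code above and below keeps reshaping):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# a made-up single-feature column, shape (n_samples, 1)
data = np.array([[10.], [20.], [40.]], dtype='float32')

scaler = MinMaxScaler(feature_range=(0, 1))
scaled = scaler.fit_transform(data)          # min -> 0, max -> 1
restored = scaler.inverse_transform(scaled)  # back to original units
print(scaled.ravel())
print(restored.ravel())  # [10. 20. 40.]
```

The same scaler instance must be reused for `inverse_transform`; fitting a new one on the predictions would use different min/max values and silently distort the results.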
I set an explicit learning rate because it gives better results than the default settings.
opt = Adam(lr=0.0000001, decay=.2)
model = Sequential()
model.add(LSTM(1028, input_shape=(1, look_back), return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64))
model.add(Dense(1))
model.compile(loss='mse', optimizer=opt, metrics=['mape', 'acc'])  # pass the Adam instance; the string 'adam' would ignore the custom learning rate
model.fit(trainX, trainY, epochs=20, batch_size=64, verbose=1)
trainPredict = model.predict(trainX)
testPredict = model.predict(testX)
trainPredict = scaler.inverse_transform(trainPredict)
trainY = scaler.inverse_transform([trainY])
testPredict = scaler.inverse_transform(testPredict)
testY = scaler.inverse_transform([testY])
print (trainPredict, trainY)
trainScore = math.sqrt(mean_squared_error(trainY[0], trainPredict[:, 0]))
print('Train Score: %.2f RMSE' % (trainScore))
testScore = math.sqrt(mean_squared_error(testY[0], testPredict[:, 0]))
print('Test Score: %.2f RMSE' % (testScore))
plt.plot(trainY[0][:100])
plt.plot(trainPredict.reshape(-1, trainPredict.shape[0])[0][:100])
plt.show()
The result: sometimes, after an epoch, the loss gradually increases while the accuracy stays very low.
432/12552 [====>.........................] - ETA: 32s - loss: 0.0059 - mean_absolute_percentage_error: 992116.1136 - acc: 0.0082
2496/12552 [====>.........................] - ETA: 32s - loss: 0.0060 - mean_absolute_percentage_error: 966677.6248 - acc: 0.0080
2560/12552 [=====>........................] - ETA: 32s - loss: 0.0061 - mean_absolute_percentage_error: 963082.6779 - acc: 0.0086
2624/12552 [=====>........................] - ETA: 31s - loss: 0.0061 - mean_absolute_percentage_error: 939593.1212 - acc: 0.0084
2688/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 957549.8326 - acc: 0.0089
2752/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 935281.5181 - acc: 0.0087
2816/12552 [=====>........................] - ETA: 31s - loss: 0.0060 - mean_absolute_percentage_error: 963213.3245 - acc: 0.0092
2880/12552 [=====>........................] - ETA: 31s - loss: 0.0059 - mean_absolute_percentage_error: 941808.8309 - acc: 0.0090
2944/12552 [======>.......................] - ETA: 31s - loss: 0.0059 - mean_absolute_percentage_error: 921335.0405 - acc: 0.0088
3008/12552 [======>.......................] - ETA: 30s - loss: 0.0059 - mean_absolute_percentage_error: 920601.9731 - acc: 0.0090
3072/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 901423.1130 - acc: 0.0088
3136/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 901941.2332 - acc: 0.0089
3200/12552 [======>.......................] - ETA: 30s - loss: 0.0060 - mean_absolute_percentage_error: 883902.6478 - acc: 0.0088
3264/12552 [======>.......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 887954.1915 - acc: 0.0089
3328/12552 [======>.......................] - ETA: 29s - loss: 0.0059 - mean_absolute_percentage_error: 889670.6806 - acc: 0.0090
3392/12552 [=======>......................] - ETA: 29s - loss: 0.0059 - mean_absolute_percentage_error: 891472.6347 - acc: 0.0091
3456/12552 [=======>......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 907832.6322 - acc: 0.0093
3520/12552 [=======>......................] - ETA: 29s - loss: 0.0060 - mean_absolute_percentage_error: 891326.8646 - acc: 0.0091
3584/12552 [=======>......................] - ETA: 28s - loss: 0.0061 - mean_absolute_percentage_error: 1068598.5278 - acc: 0.0098
3648/12552 [=======>......................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1089488.1545 - acc: 0.0101
3712/12552 [=======>......................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1070704.1027 - acc: 0.0100
3776/12552 [========>.....................] - ETA: 28s - loss: 0.0061 - mean_absolute_percentage_error: 1052556.8863 - acc: 0.0098
3840/12552 [========>.....................] - ETA: 28s - loss: 0.0060 - mean_absolute_percentage_error: 1035014.4576 - acc: 0.0096
3904/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1018047.2437 - acc: 0.0095
3968/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1044726.5623 - acc: 0.0096
4032/12552 [========>.....................] - ETA: 27s - loss: 0.0060 - mean_absolute_percentage_error: 1047885.8193 - acc: 0.0097
4096/12552 [========>.....................] - ETA: 27s - loss: 0.0061 - mean_absolute_percentage_error: 1198321.3221 - acc: 0.0095
4160/12552 [========>.....................] - ETA: 27s - loss: 0.0061 - mean_absolute_percentage_error: 1265580.6710 - acc: 0.0099
4224/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1302111.3199 - acc: 0.0102
4288/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1282676.9945 - acc: 0.0100
4352/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1263814.2721 - acc: 0.0099
4416/12552 [=========>....................] - ETA: 26s - loss: 0.0061 - mean_absolute_percentage_error: 1254763.8207 - acc: 0.0100
4480/12552 [=========>....................] - ETA: 26s - loss: 0.0060 - mean_absolute_percentage_error: 1236838.8069 - acc: 0.0098
4544/12552 [=========>....................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1244310.3624 - acc: 0.0099
4608/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1227028.4810 - acc: 0.0098
4672/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1210220.0377 - acc: 0.0096
4736/12552 [==========>...................] - ETA: 25s - loss: 0.0060 - mean_absolute_percentage_error: 1230652.4426 - acc: 0.0097
4800/12552 [==========>...................] - ETA: 25s - loss: 0.0059 - mean_absolute_percentage_error: 1214243.8915 - acc: 0.0096
4864/12552 [==========>...................] - ETA: 24s - loss: 0.0059 - mean_absolute_percentage_error: 1198267.1610 - acc: 0.0095
Answer (score: 0):
You are probably overfitting the model here, which would explain the falling accuracy and rising loss. Try hyperparameter tuning: your learning rate could be higher, and try adding dropout to avoid overfitting.
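As a sketch of those suggestions (a learning rate near the Keras default instead of 1e-7, and dropout between recurrent layers), with layer sizes that are illustrative rather than tuned:

```python
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM
from keras.optimizers import Adam

look_back = 8  # same window length as in the question

model = Sequential()
model.add(LSTM(64, input_shape=(1, look_back), return_sequences=True))
model.add(Dropout(0.2))  # randomly zero 20% of activations during training
model.add(LSTM(32))
model.add(Dropout(0.2))
model.add(Dense(1))
# the Keras default learning rate (1e-3) is usually a better starting
# point than 1e-7; the argument is `lr` in older Keras versions
model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
```

A smaller stack like this is also less prone to overfitting than five stacked LSTM layers with up to 1028 units each.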