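I'm trying to use an LSTM for time series forecasting. I trained a model whose training loss is small, but when I use it to predict the training data itself, the values it produces are far from the actual values. Here is my code: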
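# Imports assumed here; the original post did not show them.
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# conf_world is a pandas Series of cumulative counts indexed by date.
train_data = conf_world.values.reshape(-1,1)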
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(train_data)
scaled_train_data = scaler.transform(train_data)
def split_sequence(sequence, n_steps):
    X, y = list(), list()
    for i in range(len(sequence)):
        # find the end of this pattern
        end_ix = i + n_steps
        # check if we are beyond the sequence
        if end_ix > len(sequence)-1:
            break
        # gather input and output parts of the pattern
        seq_x, seq_y = sequence[i:end_ix], sequence[end_ix]
        X.append(seq_x)
        y.append(seq_y)
    return np.array(X), np.array(y)
n_input = 5
train_x, train_y = split_sequence(scaled_train_data,n_input)
n_features =1
train_x = train_x.reshape((train_x.shape[0],train_x.shape[1],n_features))
lstm_model = Sequential()
lstm_model.add(LSTM(input_shape=(n_input, n_features),units=50,activation='relu',return_sequences=True))
lstm_model.add(LSTM(100))
lstm_model.add(Dense(1))
lstm_model.compile(optimizer = 'adam', loss = 'mean_squared_error')
lstm_model.fit(train_x,train_y, epochs = 200)
predicted_data = []
batch = scaled_train_data[:n_input].copy()
current_batch = batch.reshape((1, n_input, n_features))
lstm_pred = lstm_model.predict(current_batch)[0]
for i in range(len(train_data)-5):
    lstm_pred = lstm_model.predict(current_batch)[0]
    print(current_batch,lstm_pred)
    predicted_data.append(lstm_pred)
    current_batch = np.append(current_batch[:,1:,:],[[lstm_pred]],axis=1)
prediction = pd.Series(data=scaler.inverse_transform(predicted_data).reshape(1,-1)[0].round().astype(int),index=conf_world[5:].index)
prediction
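To make the windowing step concrete, here is a minimal sketch of what `split_sequence` returns on a toy series. The function is a condensed version with the same loop bounds as the one above, and the `toy` array is a made-up stand-in for `scaled_train_data`, not the real data.

```python
import numpy as np

def split_sequence(sequence, n_steps):
    """Slice a series into (window of n_steps values, next value) pairs."""
    X, y = [], []
    for i in range(len(sequence) - n_steps):
        X.append(sequence[i:i + n_steps])
        y.append(sequence[i + n_steps])
    return np.array(X), np.array(y)

# Toy stand-in for scaled_train_data: 8 values in a single column.
toy = np.arange(8, dtype=float).reshape(-1, 1)
toy_x, toy_y = split_sequence(toy, n_steps=5)

print(toy_x.shape)                 # (3, 5, 1)
print(toy_y.shape)                 # (3, 1)
print(toy_x[0].ravel(), toy_y[0])  # [0. 1. 2. 3. 4.] [5.]
```

With 8 points and a window of 5, only 3 training samples come out. By the same arithmetic, a window of 5 over a series of length N yields N - 5 samples.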
My actual data is:
2020-01-27 2927
2020-01-28 5578
2020-01-29 6166
2020-01-30 8234
2020-01-31 9927
...
2020-04-07 1426096
2020-04-08 1511104
2020-04-09 1595350
2020-04-10 1691719
2020-04-11 1771514
Length: 76, dtype: int64
But the predicted data comes out as:
2020-01-27 8890
2020-01-28 9517
2020-01-29 10686
2020-01-30 12272
2020-01-31 14369
...
2020-04-07 96617
2020-04-08 99337
2020-04-09 102411
2020-04-10 105903
2020-04-11 109897
Length: 76, dtype: int64
What am I doing wrong? What improvements should I make to the model? I'm new to time series problems, so any help would be appreciated.
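One thing worth noting about the setup: the loss reported by `fit` is measured on windows of true values, whereas the prediction loop above feeds each prediction back in as input for the next step. The sketch below illustrates how those two evaluation modes can diverge. It is only an illustration under stated assumptions: `toy_model` is a hypothetical stand-in that slightly underestimates growth, the series is synthetic, and none of this claims to be the confirmed cause of the gap in the post.

```python
import numpy as np

def one_step_predictions(predict_fn, series, n_steps):
    """Predict each point from the TRUE previous n_steps values."""
    return np.array([predict_fn(series[i:i + n_steps])
                     for i in range(len(series) - n_steps)])

def recursive_predictions(predict_fn, series, n_steps):
    """Seed with the first true window, then feed predictions back in."""
    window = list(series[:n_steps])
    preds = []
    for _ in range(len(series) - n_steps):
        nxt = predict_fn(np.array(window))
        preds.append(nxt)
        window = window[1:] + [nxt]  # drop the oldest value, append the prediction
    return np.array(preds)

# Hypothetical stand-in for a trained model: predicts 8% growth per step.
def toy_model(window):
    return window[-1] * 1.08

# Synthetic series growing 10% per step (not the real data).
series = 100.0 * 1.10 ** np.arange(30)
n_steps = 5
actual = series[n_steps:]

one_step = one_step_predictions(toy_model, series, n_steps)
rolled = recursive_predictions(toy_model, series, n_steps)

one_step_err = np.abs(one_step - actual) / actual
rolled_err = np.abs(rolled - actual) / actual
print(f"one-step relative error at the last point:  {one_step_err[-1]:.3f}")
print(f"recursive relative error at the last point: {rolled_err[-1]:.3f}")
```

In this toy setting the one-step error stays small and constant, while the recursive error grows with every step because each prediction inherits the error of the ones before it. Running the real model through `lstm_model.predict(train_x)` on the true windows would show which of the two regimes the reported training loss corresponds to.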