绘图破折号中的绘图预测

时间:2020-11-12 19:31:59

标签: python plotly plotly-dash plotly-python

我正在尝试进行超出测试数据的时间序列预测。我遵循了以下教程以了解现在的情况(https://medium.com/swlh/a-quick-example-of-time-series-forecasting-using-long-short-term-memory-lstm-networks-ddc10dc1467d)。

现在当我可视化数据时我得到了 enter image description here

但是我只想显示预测值,而不是整个数据集。因此,类似于在底部添加一个滑块来显示预测,而不必放大。因此,滑块将使用户可以选择2020/10-2021/10的日期。

我还想在图表下方的文本框中而不是仅在图表上显示预测值(所停留的位置)。因此,每次将鼠标悬停在某个点上时,预测文本上的y值也会更新。

这是我现在拥有的代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from statsmodels.tools.eval_measures import rmse
from sklearn.preprocessing import MinMaxScaler
from keras.preprocessing.sequence import TimeseriesGenerator
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
import warnings
warnings.filterwarnings("ignore")

df = pd.read_csv('AirPassengers.csv')


df.Month = pd.to_datetime(df.Month)
df = df.set_index("Month")

scaler = MinMaxScaler()
scaler.fit(train)
train = scaler.transform(train)
n_input = 12
n_features = 1
generator = TimeseriesGenerator(train, train, length=n_input, batch_size=76)

model = Sequential()
model.add(LSTM(200, activation='relu', input_shape=(n_input, n_features)))
model.add(Dropout(0.15))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
model.fit_generator(generator,epochs=30)

pred_list = []

batch = train[-n_input:].reshape((1, n_input, n_features))

for i in range(n_input):   
    pred_list.append(model.predict(batch)[0]) 
    batch = np.append(batch[:,1:,:],[[pred_list[i]]],axis=1)

from pandas.tseries.offsets import DateOffset
add_dates = [dd.index[-1] + DateOffset(months=x) for x in range(0,13) ]
future_dates = pd.DataFrame(index=add_dates[1:],columns=dd.columns)

df_predict = pd.DataFrame(scaler.inverse_transform(pred_list),
                          index=future_dates[-n_input:].index, columns=['Prediction'])

df_proj = pd.concat([dd,df_predict], axis=1)

plot_data = [
    go.Scatter(
        x=df_proj.index,
        y=df_proj['Gauteng'],
        name='actual'
    ),
    go.Scatter(
        x=df_proj.index,
        y=df_proj['Prediction'],
        name='prediction'
    )
]

plot_layout = go.Layout(
        title='Time Series Prediction'
    )
fig = go.Figure(data=plot_data, layout=plot_layout)

import plotly as ply
ply.offline.plot(fig)

可以在https://github.com/gianfelton/12-Month-Forecast-With-LSTM/blob/master/AirPassengers.csv上找到日期集

提前谢谢

0 个答案:

没有答案