有人可以解释在for循环内发生的代码吗? LSTM预测

时间:2020-05-13 10:38:01

标签: python lstm

因此,我正在研究一种机器学习模型,并且遇到了这段代码。

n_input=12
n_features=1
pred_list = []
batch = train[-n_input:].reshape(1,n_input,n_features)

for i in range(n_input):
    pred_list.append(model.predict(batch)[0]) #predict one value
    batch = np.append(batch[:,1:,:],[[pred_list[i]]],axis=1)#append to the end of the batch list

我了解到,批量处理时,它会使用-n_input索引值直到由“:”表示的结尾,然后将数据帧重塑为(1,n_input,n_features),在这种情况下为(1,12,1)。

为了进行预测,使用了for循环,它循环了12次,等于我们的n_inputs,n_inputs是我要预测的未来周期数。

这就是我开始困惑的地方,我不太了解for循环中的代码。有人可以解释循环内的代码吗?感谢您的阅读。

相关文章:https://medium.com/swlh/a-quick-example-of-time-series-forecasting-using-long-short-term-memory-lstm-networks-ddc10dc1467d

1 个答案:

答案 0 :(得分:0)

假设您有5个值的列表

[1,2,3,4,5]

您要预测第十个值。

该代码的第一行进行单个预测6。

该代码的第二行将其添加到上一个列表中(batch [:,1:,:]指定该批处理现在将增加一个值,[[pred_list [i]]]正在添加最近的预测到最后)。因此,您最终得到:

[2,3,4,5,6]

然后,您将使用新批次重新开始循环以预测第7个值。