调用LSTM模型的预测函数时出现输入形状错误

时间:2020-10-18 00:42:26

标签: python keras deep-learning lstm

我已经安装了lstm模型。每个x和y变量有100个观测值。我使用了80个值来训练模型,并使用20个值来评估模型。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout

input_features=1
time_stamp=100
out_features=1
x_input=np.random.random((time_stamp,input_features))
y_input=np.random.random((time_stamp,out_features))

train_x=x_input[0:80,:]
test_x=x_input[80:,:]

train_y=y_input[0:80,:]
test_y=y_input[80:,:]

然后我将数据重新整形,然后再将其输入LSTM函数。(例如:用于训练x :(样本,时间步长,特征)=(1,80,1))

dataXtrain = train_x.reshape(1,80, 1)
dataXtest = test_x.reshape(1,20,1)
dataYtrain = train_y.reshape(1,80,1)
dataYtest = test_y.reshape(1,20,1)
dataXtrain.shape
(1, 80, 1)

然后我能够使用以下代码行成功拟合模型:

model = Sequential()
model.add(LSTM(20,activation = 'relu', return_sequences = True,input_shape=(dataXtrain.shape[1], 
dataXtrain.shape[2])))
model.add(Dense(1))
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(dataXtrain, dataYtrain, epochs=100, batch_size=10, verbose=1)

但是,当我预测测试数据的模型时,就会出现此错误。

y_pred = model.predict(dataXtest)
Error when checking input: expected lstm_input to have shape (80, 1) but got array with shape (20, 1)

有人可以帮我弄清楚这里出什么问题吗?

谢谢

1 个答案:

答案 0 :(得分:1)

似乎问题出在数据准备上。我认为您应该划分样本(而不是时间步长)以训练和测试数据,并且训练和测试样本的形状应相同且类似于(None, time-steps, features)

由于只有一个样本具有100个观测值(时间步长),因此可以将数据划分为包含小时间步长序列的样本。例如:

n_samples = 20
input_features = 1
time_stamp = 100
out_features = 1
x_input = np.random.random((1, time_stamp, input_features))
y_input = np.random.random((1, time_stamp, out_features))

new_time_stamp = time_stamp//n_samples
x_input = x_input.reshape(n_samples, new_time_stamp, input_features)
y_input = y_input.reshape(n_samples, new_time_stamp, out_features)

dataXtrain = x_input[:16,...]
dataXtest = x_input[16:,...]

dataYtrain = y_input[:16,...]
dataYtest = y_input[16:,...]

或者,您可以收集更多的数据样本,每个样本包含100个时间步长(取决于您的应用程序和现有数据)。

您还可以查看thisthis,它们在Keras中使用LSTM进行了全面的解释。