keras LSTM预测不佳

时间:2018-10-29 14:38:54

标签: python tensorflow keras lstm

我正在尝试根据某个序列预测值(我有5个值,例如1,2,3,4,5,并希望下一个值-6)。我正在为此使用LSTM keras。

创建训练数据:

import numpy as np 
from keras.models import Sequential
from keras.layers import LSTM,Dense
a = [float(i) for i in range(1,100)]
a = np.array(a)

data_train = a[:int(len(a)*0.9)]
data_test = a[int(len(a)*0.9):]

x = 5
y = 1
z = 0

train_x = []
train_y = []
for i in data_train:
    t = data_train[z:x]
    r = data_train[x:x+y]
    if len(r) == 0:
        break
    else:
        train_x.append(t)
        train_y.append(r)
        z = z + 1
        x = x+1

train_x = np.array(train_x)
train_y = np.array(train_y)

x = 5
y = 1
z = 0

test_x = []
test_y = []
for i in data_test:
    t = data_test[z:x]
    r = data_test[x:x+y]
    if len(r) == 0:
        break
    else:
        test_x.append(t)
        test_y.append(r)
        z = z + 1
        x = x+1

test_x = np.array(test_x)
test_y = np.array(test_y)

print(train_x.shape,train_y.shape)
print(test_x.shape,test_y.shape)

将其转换为LSTM形状:

train_x_1 = train_x.reshape(train_x.shape[0],len(train_x[0]),1)
train_y_1 = train_y.reshape(train_y.shape[0],1)
test_x_1 = test_x.reshape(test_x.shape[0],len(test_x[0]),1)
test_y_1 = test_y.reshape(test_y.shape[0],1)


print(train_x_1.shape, train_y_1.shape)
print(test_x_1.shape, test_y_1.shape)

构建和训练模型:

model = Sequential()
model.add(LSTM(32,return_sequences = False,input_shape=(trein_x_1.shape[1],1)))
model.add(Dense(1))

model.compile(loss='mse',  optimizer='adam', metrics=['accuracy'])
history = model.fit(train_x_1,
                    train_y_1,
                    epochs=20,
                    shuffle=False, 
                    batch_size=1, 
                    verbose=2, 
                    validation_data=(test_x_1,test_y_1))

但是我得到了一个非常糟糕的结果,有人可以向我解释我做错了吗。

pred = model.predict(test_x_1)
for i,a in enumerate(pred):
    print(pred[i],test_y_1[i])
[89.71895] [95.]
[89.87877] [96.]
[90.03465] [97.]
[90.18714] [98.]
[90.337006] [99.]

Thenks。

1 个答案:

答案 0 :(得分:0)

您希望网络从您用于训练的数据中推断。神经网络are not good at this。您可以尝试对数据进行规范化,以便不再使用例如相对值而不是绝对值进行推断。这当然会使这个例子变得微不足道。