Keras序列:使用三个参数指定输入形状

时间:2018-10-12 00:32:51

标签: python keras

我有一个具有以下设置的数据框:

import numpy as np

X = np.random.rand(100, 20, 3)

这里有100个时间片,20个观测值和每个观测值3个属性。

我试图弄清楚如何将以上数据传递给以下Keras序列:

from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras

# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]*3)
input_shape = (look_back, n_features, 3)
output_shape = n_features

def loss(y_true, y_pred):
  return keras.losses.mean_squared_error(y_true, y_pred)

model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')

运行此抛出:

  

ValueError:输入0与lstm_23层不兼容:预期   ndim = 3,找到的ndim = 4

有人知道我如何重塑X并将其传递给模型吗?任何建议都会有所帮助!

1 个答案:

答案 0 :(得分:0)

这似乎使事情发生了变化:

from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras

# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]) * 3
input_shape = (look_back, n_features)
output_shape = n_features

def loss(y_true, y_pred):
  return keras.losses.mean_squared_error(y_true, y_pred)

model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True))
model.add(LSTM(lstm_cells, stateful=stateful))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')

然后可以按以下方式划分训练数据:

# build training data
train_x = []
train_y = []
n_time = int(X.shape[0])
n_obs = int(X.shape[1])
n_attrs = int(X.shape[2])

# note we flatten the last dimension
for i in range(look_back, n_time-1, 1):
  train_x.append( X[i-look_back:i].reshape(look_back, n_obs * n_attrs ) )
  train_y.append( X[i+1].ravel() )

train_x = np.array(train_x)
train_y = np.array(train_y)

然后可以训练玩具模型:

model.fit(train_x, train_y, epochs=10, batch_size=10)