Keras LSTM input shape error

Time: 2017-08-30 07:21:38

Tags: python keras lstm rnn

I am trying to create a simple RNN in Keras that will learn this dataset:

x_train = [
    [0,0,0,1,-1,-1,1,0,1,0,...,0,1,-1],
    [-1,0,0,-1,-1,0,1,1,1,...,-1,-1,0],
    ...
    [1,0,0,1,1,0,-1,-1,-1,...,-1,-1,0]
]

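Here 1 means a metric increased, -1 means it decreased, and 0 means it did not change. Each array has 83 items, one for each of 83 metrics, and the output (label) for each array is a categorical array showing the effect of those metrics on a single metric: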

[[ 0.  0.  1.]
 [ 1.  0.  0.],
 [ 0.  0.  1.],
 ...
 [ 0.  0.  1.],
 [ 1.  0.  0.]]

I used the Keras LSTM layer in the following code:

def train(x, y, x_test, y_test):
    x_train = np.array(x)
    y_train = np.array(y)
    print x_train.shape
    y_train = to_categorical(y_train, 3)
    model = Sequential()
    model.add(LSTM(128,input_dim=83, input_length=3))
    model.add(Dropout(0.5))
    model.add(Dense(3, activation='softmax'))
    opt = optimizers.SGD(lr=0.1, decay=1e-2)
    model.compile(loss='categorical_crossentropy',
            optimizer=opt,
            metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=128, nb_epoch=200)

The output of print x_train.shape is (1618, 83), and when I run my code I get this error:

Traceback (most recent call last):
  File "temp.py", line 171, in <module>
    load()
  File "temp.py", line 166, in load
    train(x, y, x_test, y_test)
  File "temp.py", line 63, in train
    model.fit(x_train, y_train, batch_size=128, nb_epoch=200)
  File "/usr/local/lib/python2.7/dist-packages/keras/models.py", line 652, in fit
    sample_weight=sample_weight)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1038, in fit
    batch_size=batch_size)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 963, in _standardize_user_data
    exception_prefix='model input')
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 100, in standardize_input_data
    str(array.shape))
Exception: Error when checking model input: expected lstm_input_1 to have 3 dimensions, but got array with shape (1618, 83)

I don't want to use Embedding, and would like to add input_shape to the LSTM layer.

1 Answer:

Answer 0: (score: 1)

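LSTM is a recurrent layer, which means the input data must be three-dimensional, and this corresponds to a two-dimensional input shape (the input shape leaves out the samples dimension). In practice, the data must have shape (num_samples, timesteps, features), and the input shape must be (timesteps, features).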

In your case, both your data and your input shape are missing the timesteps dimension.
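As a rough sketch of what adding that dimension could look like, the snippet below reshapes a placeholder array with the same shape as the one in the question. The data is random stand-in data, and the choice between the two layouts is an assumption: which one is right depends on whether the 83 values are meant as 83 features observed at one step or as a sequence of 83 steps.

```python
import numpy as np

# Stand-in data with the shape reported in the question: 1618 samples, 83 values each.
# The values are random placeholders, not the asker's dataset.
rng = np.random.RandomState(0)
x_train = rng.randint(-1, 2, size=(1618, 83)).astype("float32")

# Option A (assumption): each sample is a single timestep with 83 features.
x_one_step = x_train.reshape((x_train.shape[0], 1, x_train.shape[1]))
input_shape_a = x_one_step.shape[1:]  # (timesteps, features)

# Option B (assumption): each sample is a sequence of 83 timesteps with 1 feature.
x_sequence = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
input_shape_b = x_sequence.shape[1:]  # (timesteps, features)

print(x_one_step.shape, input_shape_a)
print(x_sequence.shape, input_shape_b)

# The first layer would then take the matching input shape, for example:
#   model.add(LSTM(128, input_shape=input_shape_a))
# and the reshaped array (x_one_step here) is what gets passed to model.fit.
```

Either way, the array passed to model.fit and the input_shape given to the LSTM layer have to agree on timesteps and features.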