检查输入时出错:预期lstm_1_input有3个维度,但是有形状的数组(3653,3)

时间:2018-01-26 12:36:04

标签: r keras recurrent-neural-network

我正在尝试用R中的keras学习LSTM。我无法完全理解keras中使用的约定。

我的数据集如下所示,前3列被视为输入,最后一列被视为输出。

enter image description here

基于此,我正在尝试按如下方式构建无状态LSTM:

model %>%
  layer_lstm(units = 1024, input_shape = c(1, 3), return_sequences = T ) %>%  
  layer_lstm(units = 1024, return_sequences = F) %>% 
  # using linear activation on last layer, as output is needed in real number
  layer_dense(units = 1, activation = "linear")

model %>% compile(loss = 'mse', optimizer = 'rmsprop')

模型如下所示

Layer (type)       Output Shape       Param #             
=====================================================
lstm_1 (LSTM)      (None, 1, 1024)    4210688             
_____________________________________________________
lstm_2 (LSTM)      (None, 1024)       8392704             
_____________________________________________________
dense_3 (Dense)    (None, 1)          1025                
=====================================================
Total params: 12,604,417
Trainable params: 12,604,417
Non-trainable params: 0    
_____________________________________________________

我正在尝试按如下方式训练模型:

history <- model %>% fit(dt[,1:3], dt[,4], epochs=50, shuffle=F)

但是,当我尝试执行代码时,我收到以下错误。

  

py_call_impl中的错误(callable,dots $ args,dots $ keywords):     ValueError:检查输入时出错:预期lstm_1_input有3个维度,但是有形状的数组(3653,3)

不确定我在这里缺少什么。

更新:在浏览互联网后,我似乎需要将数据集重塑为3维(batchsize,timestep,#feature)数组。但是,我没有使用任何批次,因此不确定如何重塑我的数据。

更新于28.01.2018:这对我有用。我在我的第一个LSTM图层中使用了input_shape = c(1, 3),因为我有3个功能,而且我没有使用任何批处理。因此,我最终还是使用以下函数重塑了我的数据:

reshapeDt <- function(data){ # data is the original train matrix (training dataset)
  rows <- nrow(data)
  cols <- ncol(data)-1

  dt <- array(dim=c(rows, 1, cols))
  for(i in 1:rows){
    dt[i,1,] <- data[i,1:cols]
  }
  dt
}

这意味着对fit的调用如下所示:

model %>% fit(reshapeDt(dt), dt[,4], epochs=50, shuffle=F)

这意味着dim(reshapeDt(dt))会返回number_of_rows_in_dt 1 3

1 个答案:

答案 0 :(得分:1)

LSTM图层的输入形状应为(batch, time_steps, features)

您必须整理数据才能拥有此形状。

您似乎只有一个序列,有6个时间步长和3个功能。所以,input_shape=(6,3)。实际上,您可以将(None,3)用于可变长度的序列。

您的输入数组dt应具有形状(1,length,3)