Keras / Python - If a RNN is stateful, a complete input_shape must be provided (including batch size)

Date: 2016-03-11 23:06:22

Tags: python keras

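I am trying to implement a stateful RNN, but Keras keeps asking for "a complete input_shape (including batch size)". I have tried different values for the input_shape and input_batch_size parameters, but none of them seem to work. Can anyone shed some light on this?

Code: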

model=Sequential()      
model.add(SimpleRNN(init='uniform',output_dim=80,input_dim=len(pred_frame.columns),stateful=True,batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),input_shape=(len(pred_frame.index),len(pred_frame.columns))))
model.add(Dense(output_dim=200,input_dim=len(pred_frame.columns),init="glorot_uniform"))
model.add(Dense(output_dim=1))
model.compile(loss="mse", class_mode='scalar', optimizer="sgd")
model.fit(X=predictor_train, y=target_train, batch_size=len(pred_frame.index),show_accuracy=True)

Traceback:

File "/Users/file.py", line 1483, in Pred
model.add(SimpleRNN(init='uniform',output_dim=80,input_dim=len(pred_frame.columns),stateful=True,batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),input_shape=(len(pred_frame.index),len(pred_frame.columns))))
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 194, in __init__
super(SimpleRNN, self).__init__(**kwargs)
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 97, in __init__
super(Recurrent, self).__init__(**kwargs)
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 43, in __init__
self.set_input_shape((None,) + tuple(kwargs['input_shape']))
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 141, in set_input_shape
self.build()
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 199, in build
self.reset_states()
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 221, in reset_states
'(including batch size).')
Exception: If a RNN is stateful, a complete input_shape must be provided (including batch size).

2 answers:

Answer 0 (score: 3)

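You only need to provide the batch_input_shape= parameter, not the input_shape parameter. Also, to avoid input shape errors, make sure the training data size is a multiple of batch_size. Finally, if you use a validation split, both resulting splits must also be multiples of batch_size.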

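# ensure data size is a multiple of batch_size
# (data_size is an integer sample count, so no len() is needed below)
data_size = data_size - data_size % batch_size
# ensure validation splits are multiples of batch_size:
# increment is the fraction of the data taken up by one batch
increment = float(batch_size) / data_size
# round val_split down to a whole number of batches
val_split = float(int(val_split / increment)) * increment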
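
The rounding described in this answer can be sketched as a small standalone helper. This is only an illustration under stated assumptions: the function name align_to_batch and its return values are made up for this example and are not part of Keras, and it works on integer batch counts so that both parts come out as whole batches.

```python
def align_to_batch(data_size, batch_size, val_split):
    """Trim data_size to a multiple of batch_size and round val_split down
    so that the training and validation parts are both whole batches.

    Hypothetical helper for illustration; not a Keras function.
    """
    if batch_size < 1 or data_size < batch_size:
        raise ValueError("need at least one full batch")
    usable = data_size - data_size % batch_size  # drop the incomplete last batch
    n_batches = usable // batch_size             # number of whole batches available
    val_batches = int(val_split * n_batches)     # whole batches held out for validation
    n_val = val_batches * batch_size
    n_train = usable - n_val
    return usable, n_train, n_val, float(val_batches) / n_batches


usable, n_train, n_val, val_split = align_to_batch(1050, 100, 0.25)
```

With 1050 samples, a batch size of 100 and a requested split of 0.25, this gives 1000 usable samples, 800 for training, 200 for validation, and an adjusted split of 0.2. Returning the integer counts alongside the fraction makes it possible to slice the data by hand instead of relying on how a float fraction is turned into a count.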

Answer 1 (score: 0)

In your definition of SimpleRNN, remove input_dim and input_shape, and set:

batch_input_shape = (Number_Of_sequences, Size_Of_Each_Sequence,
                     Shape_Of_Element_In_Each_Sequence) 

batch_input_shape should be a tuple of length at least 3.

If you pass the sequences one at a time, set:

Number_Of_sequences = 1

If the length of the sequences is not fixed, set:

Size_Of_Each_Sequence = None
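
As a rough sketch of how these pieces fit together, the helper below assembles such a tuple and rejects a missing batch size. The name make_batch_input_shape and the feature count of 5 are assumptions made for this example only; the commented-out layer call reuses just the arguments already shown in the question.

```python
def make_batch_input_shape(batch_size, timesteps, feature_shape):
    """Return (batch_size, timesteps) + feature_shape for a stateful RNN layer.

    Hypothetical helper for illustration; not a Keras function.
    timesteps may be None when the sequence length is not fixed.
    """
    if batch_size is None or batch_size < 1:
        raise ValueError("a stateful RNN needs a fixed, positive batch size")
    if timesteps is not None and timesteps < 1:
        raise ValueError("timesteps must be positive, or None if not fixed")
    feature_shape = tuple(feature_shape)
    if len(feature_shape) < 1:
        raise ValueError("each sequence element needs at least one dimension")
    return (batch_size, timesteps) + feature_shape


# Sequences fed one at a time, variable length, 5 features per time step
# (5 is an assumed value standing in for len(pred_frame.columns)).
shape = make_batch_input_shape(1, None, (5,))

# The tuple would then replace input_dim and input_shape in the layer, e.g.:
# model.add(SimpleRNN(init='uniform', output_dim=80, stateful=True,
#                     batch_input_shape=shape))
```

Here shape is (1, None, 5), a 3-tuple as the answer requires, whereas the question passed a 2-tuple of (rows, columns) with no time-step axis.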