我目前正在研究基于LSTM的研究问题。但是,如下面将要演示的那样,在keras中使用RNN时,我遇到了上述错误。
我使用TF版本1.12.0 和keras 2.2.4。
这似乎适用于LSTMCell之类的单元格,但不适用于UGRNNCell。。不知道如何解决此问题。
cell1=tf.contrib.rnn.UGRNNCell(64)
cell2=tf.contrib.rnn.UGRNNCell(64)
cell3=tf.contrib.rnn.UGRNNCell(64)
cell4=tf.contrib.rnn.UGRNNCell(64)
这是我的模型:
model = Sequential()
model.add(RNN(cell1, input_shape=(train_X.shape[1:]),return_sequences=True))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(RNN(cell2,return_sequences=True))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(RNN(cell3, return_sequences=True))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(RNN(cell4,return_sequences=False))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
这会导致上述错误,但是当将单元格替换为LSTMCell时,不会出现任何问题。
我希望它可以在任何类型的单元中无缝运行。