功能性API连接问题(CNN输出到LSTM输入)

时间:2019-09-07 11:33:22

标签: keras lstm

我正在尝试CNN的输出以在Keras中的seq2seq模型的编码器中馈送LSTM。但是LSTM没有收到CNN的输出。

有人有主意吗? 非常感谢。

from keras.layers import Input, LSTM
from keras.models import Model

import keras
regulariser = None

num_encoder_input_features=128
hidden_neurons=64

def cnn_encode(input_tensor):
    c1 = keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(201, 201, 1),data_format='channels_last')(input_tensor)
    c1 = keras.layers.MaxPooling2D((2, 2))(c1)
    f1 = keras.layers.Flatten()(c1)
    f1 = keras.layers.Dense(128)(f1)
    return f1

input_tensor = keras.layers.Input(shape=( 201, 201, 1))
cnn_output_tensor = cnn_encode(input_tensor)

encoder = LSTM(hidden_neurons, return_state=True)
encoder_outputs, state_h, state_c = encoder(cnn_output_tensor)
encoder_states = [state_h, state_c]

decoder_inputs = Input(shape=(None, 201, 201, 1))
decoder_lstm = LSTM(hidden_neurons, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)

model = Model([input_tensor, decoder_inputs], decoder_outputs)

回溯(大多数/model/05_Seq2Seq_Convolution_Deconvolution.py”,第22行,在     encoder_outputs,state_h,state_c =编码器(cnn_output_tensor)   调用中的文件“ /anaconda3/envs/dev_tf1.13/lib/python3.6/site-packages/keras/layers/recurrent.py”,第532行     返回超级(RNN,自我)。调用(输入,** kwargs)   在调用中的文件“ /anaconda3/envs/dev_tf1.13/lib/python3.6/site-packages/keras/engine/base_layer.py”,第414行     self.assert_input_compatibility(输入)   在assert_input_compatibility中的文件“ /anaconda3/envs/dev_tf1.13/lib/python3.6/site-packages/keras/engine/base_layer.py”第311行     str(K.ndim(x))) ValueError:输入0与lstm_1层不兼容:预期ndim = 3,找到ndim = 2

0 个答案:

没有答案