How can I use an encoder-decoder in Keras to model an output sequence whose length differs from the input?

Date: 2019-06-14 15:26:40

Tags: keras seq2seq encoder-decoder

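I am trying to build an encoder-decoder architecture for video in Keras. I have a convolutional LSTM (ConvLSTM2D) model whose encoder and decoder work fine when the input and output sequences have the same length, but I cannot get it to work when the output length differs from the input length.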

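The code itself comes from another Stack Overflow question, although that one was about an autoencoder. I now want a decoder that produces a separate target sequence.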

The code below works fine, but how do I change the final output so that the decoded sequence is not 48 time steps long but, say, 20?
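The reason the output is always 48 steps can be traced with plain arithmetic. The sketch below is a hypothetical helper (time_steps is not a Keras function) that assumes MaxPooling3D with padding='same' and its default strides shrinks the time axis to ceil(n / factor), UpSampling3D multiplies it by its factor, and ConvLSTM2D (return_sequences=True) and Conv3D (padding='same') leave it unchanged. With two halvings and two doublings, 48 in always means 48 out, and with up-sampling alone a 12-step bottleneck can only reach multiples of 12, so 20 needs cropping, padding, or a decoder whose length does not come from the input.

```python
def time_steps(n_in, ops):
    """Trace the time-axis length through a list of (kind, factor) ops.

    'pool': MaxPooling3D with padding='same' -> ceil(n / factor)
    'up'  : UpSampling3D                      -> n * factor
    Layers that keep the time length (ConvLSTM2D with return_sequences=True,
    Conv3D with padding='same') are simply not listed.
    """
    n = n_in
    trace = [n]
    for kind, factor in ops:
        if kind == 'pool':
            n = -(-n // factor)  # ceiling division, as padding='same' does
        elif kind == 'up':
            n = n * factor
        else:
            raise ValueError(f"unknown op: {kind}")
        trace.append(n)
    return trace


# Time factors of the two MaxPooling3D and two UpSampling3D layers in the question
stack = [('pool', 2), ('pool', 2), ('up', 2), ('up', 2)]
print(time_steps(48, stack))  # [48, 24, 12, 24, 48]

# With up-sampling alone, a 12-step bottleneck can only grow to multiples of 12
print([12 * k for k in range(1, 5)])  # [12, 24, 36, 48]
```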

from keras.layers import Input, ConvLSTM2D, MaxPooling3D, UpSampling3D, Conv3D, concatenate
from keras.models import Model

# Input: 48 frames of 32x16 single-channel images
input_seq = Input(shape=(48, 32, 16, 1))

# Encoder block 1
a = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(input_seq)
a = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(a)
# Halves time, height and width: (48, 32, 16) -> (24, 16, 8)
b = MaxPooling3D((2, 2, 2), padding='same')(a)

# Encoder block 2
c = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(b)
c = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(c)
# (24, 16, 8) -> (12, 8, 4)
encoded = MaxPooling3D((2, 2, 2), padding='same', name="encoder")(c)

# Decoder block 1
d = ConvLSTM2D(40, (3, 3), activation='relu', padding='same', return_sequences=True)(encoded)
d = ConvLSTM2D(40, (3, 3), activation='relu', padding='same', return_sequences=True)(d)
# Doubles time, height and width: (12, 8, 4) -> (24, 16, 8)
e = UpSampling3D((2, 2, 2))(d)

## Skip connection
# merge_one = concatenate([b, e])

# Decoder block 2
f = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(e)
f = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(f)
# (24, 16, 8) -> (48, 32, 16)
g = UpSampling3D((2, 2, 2))(f)

# One output channel per frame: (None, 48, 32, 16, 1)
decoded = Conv3D(1, (3, 3, 3), activation='sigmoid', padding='same')(g)

model = Model(input_seq, decoded)
model.compile(optimizer='adadelta', loss='binary_crossentropy')
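One common way to decouple the output length from the input length is the seq2seq "repeat vector" pattern, sketched here as an assumption rather than a tested answer: the last encoder ConvLSTM2D returns only its final state (return_sequences=False), that state is flattened, repeated OUT_STEPS times with RepeatVector, reshaped back into frames, and the decoder ConvLSTM2D layers unroll over those OUT_STEPS steps. To let OUT_STEPS alone control the time axis, pooling and up-sampling here act only on the spatial axes ((1, 2, 2)). OUT_STEPS = 20, the reduced layer count, and the default ConvLSTM2D activations are illustrative choices, not part of the original model.

```python
from keras.layers import (Input, ConvLSTM2D, MaxPooling3D, Flatten,
                          RepeatVector, Reshape, UpSampling3D, Conv3D)
from keras.models import Model

IN_STEPS, OUT_STEPS = 48, 20  # OUT_STEPS: desired decoder length (assumed)

inp = Input(shape=(IN_STEPS, 32, 16, 1))

# Encoder: spatial-only pooling, the time axis is left untouched
x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(inp)
x = MaxPooling3D((1, 2, 2), padding='same')(x)           # (48, 16, 8, 40)
x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(x)
x = MaxPooling3D((1, 2, 2), padding='same')(x)           # (48, 8, 4, 40)
# Final encoder layer keeps only its last state: a fixed-size summary
state = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=False)(x)  # (8, 4, 40)

# Repeat the summary OUT_STEPS times along a new time axis
rep = Flatten()(state)                                    # (1280,)
rep = RepeatVector(OUT_STEPS)(rep)                        # (20, 1280)
rep = Reshape((OUT_STEPS, 8, 4, 40))(rep)                 # (20, 8, 4, 40)

# Decoder: unrolls over OUT_STEPS steps, spatial-only up-sampling
y = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(rep)
y = UpSampling3D((1, 2, 2))(y)                            # (20, 16, 8, 40)
y = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(y)
y = UpSampling3D((1, 2, 2))(y)                            # (20, 32, 16, 40)
decoded = Conv3D(1, (3, 3, 3), activation='sigmoid', padding='same')(y)

model = Model(inp, decoded)
model.compile(optimizer='adadelta', loss='binary_crossentropy')
print(tuple(decoded.shape))  # (None, 20, 32, 16, 1)
```

Because the decoder length comes only from RepeatVector, the targets passed to fit would need shape (batch, 20, 32, 16, 1). The trade-off is that the whole input sequence has to be squeezed into a single 8x4x40 state.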


model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 48, 32, 16, 1)     0         
_________________________________________________________________
conv_lst_m2d_17 (ConvLSTM2D) (None, 48, 32, 16, 40)    59200     
_________________________________________________________________
conv_lst_m2d_18 (ConvLSTM2D) (None, 48, 32, 16, 40)    115360    
_________________________________________________________________
max_pooling3d_3 (MaxPooling3 (None, 24, 16, 8, 40)     0         
_________________________________________________________________
conv_lst_m2d_19 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
conv_lst_m2d_20 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
encoder (MaxPooling3D)       (None, 12, 8, 4, 40)      0         
_________________________________________________________________
conv_lst_m2d_21 (ConvLSTM2D) (None, 12, 8, 4, 40)      115360    
_________________________________________________________________
conv_lst_m2d_22 (ConvLSTM2D) (None, 12, 8, 4, 40)      115360    
_________________________________________________________________
up_sampling3d_5 (UpSampling3 (None, 24, 16, 8, 40)     0         
_________________________________________________________________
conv_lst_m2d_23 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
conv_lst_m2d_24 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
up_sampling3d_6 (UpSampling3 (None, 48, 32, 16, 40)    0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 48, 32, 16, 1)     1081      
=================================================================
Total params: 867,801
Trainable params: 867,801
Non-trainable params: 0
_________________________________________________________________
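If you would rather keep the question's architecture, a smaller change (a sketch, assuming the target length is at most 24) is to crop the time axis of the 24-step tensor that the summary shows after up_sampling3d_5, and make the final UpSampling3D act on the spatial axes only. Here Cropping3D removes 2 steps from each end of the time axis; which steps to drop is an arbitrary choice, and the decoder still follows the input's timing rather than learning an independent output sequence. Each block is cut to one ConvLSTM2D to keep the sketch short.

```python
from keras.layers import (Input, ConvLSTM2D, MaxPooling3D, UpSampling3D,
                          Conv3D, Cropping3D)
from keras.models import Model

OUT_STEPS = 20  # desired output length (assumed, must be <= 24 here)

inp = Input(shape=(48, 32, 16, 1))
x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(inp)
x = MaxPooling3D((2, 2, 2), padding='same')(x)   # time 48 -> 24
x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)   # time 24 -> 12
x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(x)
x = UpSampling3D((2, 2, 2))(x)                   # time 12 -> 24, space (16, 8)

# Trim the time axis from 24 down to OUT_STEPS, splitting the cut across both ends
extra = 24 - OUT_STEPS
x = Cropping3D(cropping=((extra // 2, extra - extra // 2), (0, 0), (0, 0)))(x)

x = ConvLSTM2D(40, (3, 3), padding='same', return_sequences=True)(x)
# Up-sample height and width only, so the time length stays at OUT_STEPS
x = UpSampling3D((1, 2, 2))(x)                   # space (16, 8) -> (32, 16)
decoded = Conv3D(1, (3, 3, 3), activation='sigmoid', padding='same')(x)

model = Model(inp, decoded)
model.compile(optimizer='adadelta', loss='binary_crossentropy')
print(tuple(decoded.shape))  # (None, 20, 32, 16, 1)
```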
