Problem merging two layers in an LSTM Seq2Seq model for a Q&A use case

Date: 2018-06-09 16:20:10

Tags: keras deep-learning lstm

I am trying to build a Q&A model based on the bAbI Task 8 example, but I cannot manage to merge the two input layers into a single layer. Here is my current model architecture:

# Assumed imports for this snippet (Keras 2 functional API):
from keras.models import Model
from keras.layers import Input, Embedding, Reshape, LSTM, RepeatVector, Dense, dot

# Story branch: one-hot-encoded story of shape (story_maxlen, vocab_size)
story_input = Input(shape=(story_maxlen,vocab_size), name='story_input')
story_input_proc = Embedding(vocab_size, latent_dim, name='story_input_embed', input_length=story_maxlen)(story_input)
story_input_proc = Reshape((latent_dim,story_maxlen), name='story_input_reshape')(story_input_proc)

# Query branch: one-hot-encoded query of shape (query_maxlen, vocab_size)
query_input = Input(shape=(query_maxlen,vocab_size), name='query_input')
query_input_proc = Embedding(vocab_size, latent_dim, name='query_input_embed', input_length=query_maxlen)(query_input)
query_input_proc = Reshape((latent_dim,query_maxlen), name='query_input_reshape')(query_input_proc)

# Merge the two branches by dotting them along the latent dimension
story_query = dot([story_input_proc, query_input_proc], axes=(1, 1), name='story_query_merge')

# Encoder: keep the final hidden and cell states to initialize the decoder
encoder = LSTM(latent_dim, return_state=True, name='encoder')
encoder_output, state_h, state_c = encoder(story_query)
encoder_output = RepeatVector(3, name='encoder_3dim')(encoder_output)  # unroll the decoder for 3 output steps
encoder_states = [state_h, state_c]

# Decoder: emit a softmax over the vocabulary at each output step
decoder = LSTM(latent_dim, return_sequences=True, name='decoder')(encoder_output, initial_state=encoder_states)
answer_output = Dense(vocab_size, activation='softmax', name='answer_output')(decoder)

model = Model([story_input, query_input], answer_output)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

And here is the output of model.summary():

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
story_input (InputLayer)        (None, 358, 38)      0                                            
__________________________________________________________________________________________________
query_input (InputLayer)        (None, 5, 38)        0                                            
__________________________________________________________________________________________________
story_input_embed (Embedding)   (None, 358, 64)      2432        story_input[0][0]                
__________________________________________________________________________________________________
query_input_embed (Embedding)   (None, 5, 64)        2432        query_input[0][0]                
__________________________________________________________________________________________________
story_input_reshape (Reshape)   (None, 64, 358)      0           story_input_embed[0][0]          
__________________________________________________________________________________________________
query_input_reshape (Reshape)   (None, 64, 5)        0           query_input_embed[0][0]          
__________________________________________________________________________________________________
story_query_merge (Dot)         (None, 358, 5)       0           story_input_reshape[0][0]        
                                                                 query_input_reshape[0][0]        
__________________________________________________________________________________________________
encoder (LSTM)                  [(None, 64), (None,  17920       story_query_merge[0][0]          
__________________________________________________________________________________________________
encoder_3dim (RepeatVector)     (None, 3, 64)        0           encoder[0][0]                    
__________________________________________________________________________________________________
decoder (LSTM)                  (None, 3, 64)        33024       encoder_3dim[0][0]               
                                                                 encoder[0][1]                    
                                                                 encoder[0][2]                    
__________________________________________________________________________________________________
answer_output (Dense)           (None, 3, 38)        2470        decoder[0][0]                    
==================================================================================================
Total params: 58,278
Trainable params: 58,278
Non-trainable params: 0
__________________________________________________________________________________________________

where vocab_size = 38, story_maxlen = 358, query_maxlen = 5, latent_dim = 64, and the batch size is 64.
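(For what it's worth, the parameter counts in the summary are consistent with these values: each Embedding has 38 × 64 = 2,432 weights; the encoder LSTM, whose input has 5 features per step, has 4 × ((5 + 64) × 64 + 64) = 17,920; the decoder has 4 × ((64 + 64) × 64 + 64) = 33,024; and the Dense layer has (64 + 1) × 38 = 2,470, for a total of 58,278.)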

When I try to train this model, I get the following error:

Input to reshape is a tensor with 778240 values, but the requested shape has 20480

Here is how those two values break down:

input_to_reshape = batch_size * latent_dim * query_maxlen * vocab_size

requested_shape = batch_size * latent_dim * query_maxlen
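Plugging in my numbers: 64 × 64 × 5 × 38 = 778,240 and 64 × 64 × 5 = 20,480, which matches the two values in the error message exactly.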

Where I'm at

I believe the error message is saying that the tensor coming into the query_input_reshape layer has shape (?, 5, 38, 64), while the layer expects a tensor of shape (?, 5, 64) (see the formulas above), but I could be wrong.
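To test this theory, here is a minimal standalone probe (my own sketch, separate from the model above): since Embedding looks up a latent_dim-sized vector for every scalar it receives, feeding it one-hot rows instead of integer indices should produce a 4D tensor at runtime.

import numpy as np
from keras.models import Model
from keras.layers import Input, Embedding

probe_input = Input(shape=(5, 38))             # (query_maxlen, vocab_size) one-hot rows
probe_embed = Embedding(38, 64)(probe_input)   # one 64-dim vector per scalar entry
probe = Model(probe_input, probe_embed)

dummy = np.zeros((1, 5, 38))                   # a single all-zeros "query"
print(probe.predict(dummy).shape)              # I expect (1, 5, 38, 64), i.e. 4D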

When I change Reshape's target_shape to 3D (i.e. Reshape((latent_dim, query_maxlen, vocab_size))), I get the error total size of new array must be unchanged, which makes no sense to me because the input is 3D. You would think Reshape((latent_dim, query_maxlen)) would raise that error instead, since it would turn a 3D tensor into a 2D one, yet that compiles just fine, so I am not sure what is going on there.
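My best guess is that the compile-time check only sees the declared (?, 5, 64) shape: Reshape((latent_dim, query_maxlen)) asks for 64 × 5 = 320 values per sample, which matches the declared 5 × 64 = 320, so it compiles (and since target_shape excludes the batch axis, its output would still be 3D anyway); Reshape((latent_dim, query_maxlen, vocab_size)) asks for 64 × 5 × 38 = 12,160, which does not match 320, hence the total size of new array must be unchanged error. But I could be wrong about this too.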

The only reason I am using Reshape at all is that I need to merge the two tensors into a single input for the LSTM encoder. When I tried to drop the Reshape layers, I just got dimension-mismatch errors while compiling the model. The architecture above at least compiles; I just cannot train it.

Can someone help me figure out how to merge the story_input and query_input layers? Thanks!

0 Answers:

No answers yet.