I first built a chatbot using a seq2seq model. I implemented it without an attention layer, and it still achieved a good result of about 70%.
Now, to improve accuracy, I am adding an attention layer to the seq2seq encoder-decoder model.
Everything is implemented in Keras.
When I add the attention layer to the seq2seq model, I get: IndexError: list index out of range.
My code is below.
# Assumed imports for this snippet; tokenizer, final_questions,
# final_answers and vocab_size are defined earlier in my notebook
import numpy as np
import keras
from keras.preprocessing.sequence import pad_sequences

# encoder_input_data
tr_que = tokenizer.texts_to_sequences(final_questions)
max_que_len = max([len(x) for x in tr_que])
que_pad = pad_sequences(tr_que, maxlen=max_que_len, padding='post')
encoder_input_data = np.array(que_pad)
print('Encoder input shape :',encoder_input_data.shape)
# decoder_input_data
tr_ans = tokenizer.texts_to_sequences(final_answers)
max_ans_len = max([len(x) for x in tr_ans])
ans_pad = pad_sequences(tr_ans, maxlen=max_ans_len, padding='post' )
decoder_input_data = np.array(ans_pad)
print('Decoder input shape :',decoder_input_data.shape)
# decoder_output_data
tr_ans = tokenizer.texts_to_sequences(final_answers)
for i in range(len(tr_ans)):
    tr_ans[i] = tr_ans[i][1:]  # drop the first token so targets are one step ahead of the decoder inputs
ans_pad = pad_sequences(tr_ans, maxlen=max_ans_len, padding='post' )
cat_ans = keras.utils.to_categorical(ans_pad, vocab_size)
decoder_output_data = np.array(cat_ans)
print('Decoder output shape :',decoder_output_data.shape)
Output:
Encoder input shape : (5806, 22)
Decoder input shape : (5806, 38)
Decoder output shape : (5806, 38, 5661)
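Just to confirm the data preparation: the decoder targets are the decoder inputs shifted left by one step (teacher forcing). A quick sanity check (my own addition, using the arrays built above):

# Target at step t should be the one-hot of the input token at step t+1
row = 0
assert np.array_equal(decoder_output_data[row, :-1].argmax(axis=-1),
                      decoder_input_data[row, 1:])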
The code above prepares the inputs and outputs for the seq2seq model.
Next, I imported an attention layer from the reference link below:
https://towardsdatascience.com/light-on-math-ml-attention-with-keras-dc8dbc1fad39
Following that reference, I am trying to apply the attention layer to my seq2seq model, as shown below.
# Imports for this snippet; AttentionLayer is the class provided in the
# article linked above (the import path depends on where that file is
# saved, so this line is an assumption)
from keras.layers import Input, Embedding, LSTM, Dense, Concatenate, TimeDistributed
from keras.models import Model
from attention import AttentionLayer

# Encoder inputs
Embedding_layer = Embedding(vocab_size, 300, mask_zero=True)
encoder_inputs = Input(shape=(None,), name='encoder_input_layer')
encoder_embedding = Embedding_layer(encoder_inputs)
encoder_outputs , state_h , state_c = LSTM(1024 , return_state=True, name='encoder_layer')(encoder_embedding)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,), name='decoder_input_layer')
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_embedding = Embedding_layer(decoder_inputs)
decoder_lstm = LSTM(1024, return_state=True, return_sequences=True, name='decoder_layer')
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)
#Attention layer
#https://towardsdatascience.com/light-on-math-ml-attention-with-keras-dc8dbc1fad39
#Inputs to the attention layer are encoder_outputs and decoder_outputs
#we can get this pre-computed attention class "AttentionLayer" from above reference link
attention_layer = AttentionLayer(name='attention_layer')
attention_output = attention_layer([encoder_outputs, decoder_outputs])
#Concatenate the attention_output and decoder_outputs as an input to the softmax layer.
attention_decoder_input = Concatenate(axis=-1, name='concat_layer')([decoder_outputs, attention_output])
#Dense layer
decoder_dense = TimeDistributed(Dense(vocab_size, activation='softmax', name = 'softmax_layer'))
#decoder outputs
#attention_decoder_input to the dense layer
output = decoder_dense(attention_decoder_input)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], output)
#compiling the model
model.compile(optimizer='adam', loss='categorical_crossentropy')
#model summary
model.summary()
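For reference, the shapes of the two tensors that feed the attention layer can be inspected like this (a diagnostic sketch; the comments show the shapes I would expect from the layers above):

# The encoder LSTM is built without return_sequences=True, so
# encoder_outputs should be 2-D here, while decoder_outputs is 3-D
print(encoder_outputs.shape)   # expect: (None, 1024)
print(decoder_outputs.shape)   # expect: (None, None, 1024)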
After applying the attention layer as above, I run into the following error:
--------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-29-af4228e23d00> in <module>()
19 #Inputs to the attention layer are encoder_outputs and decoder_outputs
20 attention_layer = AttentionLayer(name='attention_layer')
---> 21 attention_output = attention_layer([encoder_outputs, decoder_outputs])
22
23 #Concatenate the attention_output and decoder_outputs as an input to the softmax layer.
3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_shape.py in __getitem__(self, key)
868 return self._dims[key].value
869 else:
--> 870 return self._dims[key]
871 else:
872 if isinstance(key, slice):
IndexError: list index out of range
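Reading the traceback, the failure happens while the attention layer indexes the shape of one of its inputs. My guess (only an assumption on my part) is that the layer expects the full 3-D encoder output sequence, which would mean building the encoder LSTM with return_sequences=True:

# Sketch of the change I suspect is needed (assumption): return the full
# encoder output sequence (batch, timesteps, units) instead of only the
# final 2-D hidden state
encoder_outputs, state_h, state_c = LSTM(1024, return_sequences=True,
                                         return_state=True,
                                         name='encoder_layer')(encoder_embedding)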
Can anyone help me figure out what I am doing wrong?
Thanks in advance.