如何使用tensorflow注意层?

时间:2020-06-27 19:37:52

标签: python tensorflow keras seq2seq

我正在尝试了解如何使用此处显示的tf.keras.layers.Attention

Tensorflow Attention Layer

我正在尝试将其与编码器解码器seq2seq模型一起使用。下面是我的代码:

encoder_inputs = Input(shape=(max_len_text,)) 
enc_emb = Embedding(x_voc_size, latent_dim,trainable=True)(encoder_inputs) 
encoder_lstm=LSTM(latent_dim, return_state=True, return_sequences=True) 
encoder_outputs, state_h, state_c= encoder_lstm(enc_emb) 

decoder_inputs = Input(shape=(max_len_summary,)) 
dec_emb_layer = Embedding(y_voc_size, latent_dim,trainable=True) 
dec_emb = dec_emb_layer(decoder_inputs) 

decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) 
decoder_outputs,decoder_fwd_state, decoder_back_state = decoder_lstm(dec_emb,initial_state=[state_h, state_c]) 

我的问题是,如何在此模型中使用keras中的给定Attention层?我听不懂他们的文件。

1 个答案:

答案 0 :(得分:0)

如果您使用的是RNN,我不建议您使用上述类。

在分析tf.keras.layers.Attention Github代码以更好地理解您的难题时,我可能遇到的第一行是-“该类适用于Dense或CNN网络,不适用于RNN网络”

我建议您将自己的seq写入seq模型,这可以用不到十二行代码来完成。例如:https://www.tensorflow.org/tutorials/text/nmt_with_attention

要编写自己的自定义关注层(基于您是否喜欢Bahdanau,Luong,Raffel,Yang等),也许这篇概述基本要素的帖子可能会有所帮助:Custom Attention Layer using in Keras