我正在尝试掌握分层注意力网络(HAN)的概念,我在网上找到的大多数代码或多或少与此处的代码相似:https://medium.com/jatana/report-on-text-classification-using-cnn-rnn-han-f0e887214d5f:
embedding_layer=Embedding(len(word_index)+1,EMBEDDING_DIM,weights=[embedding_matrix],
input_length=MAX_SENT_LENGTH,trainable=True)
sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32', name='input1')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm)
review_input = Input(shape=(MAX_SENTS,MAX_SENT_LENGTH), dtype='int32', name='input2')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(len(macronum), activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)
我的问题是:这里的输入层代表什么?我猜测input1代表用嵌入层包装的句子,但是在这种情况下,input2是什么?它是sendEncoder的输出吗?在这种情况下,它应该是浮点数,或者如果它是嵌入单词的另一层,那么它也应该被嵌入层包裹。
答案 0 :(得分:1)
HAN模型按层次结构处理文本:它需要将文档拆分为句子(这就是input2
的形状为(MAX_SENTS,MAX_SENT_LENGTH)
的原因);然后使用sentEncoder
模型独立处理每个句子(这就是input1
的形状为(MAX_SENT_LENGTH,)
的原因),最后将所有编码的句子一起处理。
因此,在您的代码中,整个模型都存储在model
中,其输入层为input2
,您将获得文件,该文件已被拆分为句子,并且其单词已被整数编码(它与嵌入层兼容)。另一个输入层属于sentEncoder
模型,该模型在model
内部使用(并非直接由您使用):
review_encoder = TimeDistributed(sentEncoder)(review_input)
答案 1 :(得分:1)
马苏德的回答是正确的,但我将用自己的话在这里重写:
所以input2更是模型输入的代理。