I'm using this implementation of a Hierarchical Attention Network for document classification. Like the author, I want to use an embedding layer with trainable parameters. The problem is that the TimeDistributed layer (line 186) appears to replicate the trainable parameters, which leads to a huge number of parameters to train: for the problem I'm working on, I get over 500M for that layer alone. If I set the trainable flag to False, the trainable-parameter count drops back to a normal value.

I think the issue is that the TimeDistributed layer is replicating the embedding matrix across the whole chain (is that what it's supposed to do?). How can I set up the model so that a single, fixed embedding matrix is shared instead?
```
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
main_input (InputLayer)         (None, 20, 40)       0
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, 20, 512)      543574704   main_input[0][0]
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) (None, 20, 512)      1574912     time_distributed_1[0][0]
__________________________________________________________________________________________________
attention_with_context_2 (Atten (None, 512)          263168      bidirectional_2[0][0]
__________________________________________________________________________________________________
dense_out_dom (Dense)           (None, 21)           10773       attention_with_context_2[0][0]
```
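To illustrate the kind of sharing I'm after: a single embedding matrix should be able to serve every sentence of every document without being copied per timestep. A minimal NumPy sketch (the vocabulary size and embedding dimension here are made-up numbers, not taken from my model):

```python
import numpy as np

# Hypothetical sizes; only the (batch, 20, 40) input shape matches my model.
vocab_size, embed_dim = 1000, 8
docs, sents, words = 2, 20, 40      # (batch, sentences, words), as in main_input

# ONE embedding matrix -- this is the full set of parameters I expect.
embedding = np.random.rand(vocab_size, embed_dim)

# Word-index input of shape (docs, sents, words).
x = np.random.randint(0, vocab_size, size=(docs, sents, words))

# Looking up the same matrix for every timestep is a single fancy-indexing
# operation; no per-sentence copy of `embedding` is made.
out = embedding[x]                  # shape: (docs, sents, words, embed_dim)

print(out.shape)                    # (2, 20, 40, 8)
print(embedding.size)               # parameter count stays vocab_size * embed_dim = 8000
```

That is the behaviour I want from the Keras model: the parameter count of the embedding should stay `vocab_size * embed_dim`, independent of how many timesteps TimeDistributed wraps.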