Keras: Embedding in TimeDistributed

Asked: 2019-03-30 01:22:42

Tags: keras

I am using this implementation of a Hierarchical Attention Network for document classification. Like the author, I want to use an embedding layer with trainable parameters. The problem is that the TimeDistributed layer (line 186) replicates the trainable parameters, which leads to an enormous number of parameters to train: in the problem I am working on, I get over 500M trainable parameters from that layer alone. If I set trainable to False, the number of trainable parameters drops back to a normal value.
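To make the construction concrete, here is a minimal sketch of the setup I am describing. The shapes match the summary below; VOCAB_SIZE and EMB_DIM are hypothetical placeholders:

```python
from keras.layers import Input, Embedding, GRU, Bidirectional, TimeDistributed
from keras.models import Model

MAX_SENTS, MAX_WORDS = 20, 40     # a document is 20 sentences of 40 words, per the summary
VOCAB_SIZE, EMB_DIM = 50000, 300  # hypothetical sizes; a large vocabulary drives the count up

# Inner sentence encoder: a trainable Embedding followed by a bidirectional GRU.
sent_input = Input(shape=(MAX_WORDS,), dtype='int32')
embedded = Embedding(VOCAB_SIZE, EMB_DIM, trainable=True)(sent_input)
encoded = Bidirectional(GRU(256))(embedded)              # -> (None, 512)
sent_encoder = Model(sent_input, encoded)

# Outer document model: TimeDistributed applies the same encoder to every sentence.
doc_input = Input(shape=(MAX_SENTS, MAX_WORDS), dtype='int32')
doc_encoded = TimeDistributed(sent_encoder)(doc_input)   # -> (None, 20, 512)
model = Model(doc_input, doc_encoded)
model.summary()  # time_distributed_1 reports all of sent_encoder's parameters
```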

I think the problem is that the TimeDistributed layer is replicating the embedding matrix along the whole chain (as it is supposed to do). But how can I set up the model so that a single, fixed embedding matrix is shared? (A sketch of what I have in mind follows the summary below.)

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
main_input (InputLayer)         (None, 20, 40)       0                                            
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, 20, 512)      543574704   main_input[0][0]                 
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) (None, 20, 512)      1574912     time_distributed_1[0][0]         
__________________________________________________________________________________________________
attention_with_context_2 (Atten (None, 512)          263168      bidirectional_2[0][0]            
__________________________________________________________________________________________________
dense_out_dom (Dense)           (None, 21)           10773       attention_with_context_2[0][0]  
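
What I mean by sharing a fixed matrix is something like the following sketch, assuming Embedding accepts the higher-rank integer input directly (it is a plain index lookup); `pretrained`, VOCAB_SIZE, and EMB_DIM are hypothetical placeholders:

```python
import numpy as np
from keras.layers import Input, Embedding, GRU, Bidirectional, TimeDistributed
from keras.models import Model

MAX_SENTS, MAX_WORDS = 20, 40
VOCAB_SIZE, EMB_DIM = 50000, 300                 # hypothetical sizes
pretrained = np.zeros((VOCAB_SIZE, EMB_DIM))     # stand-in for real pre-trained vectors

# One frozen embedding layer, created once and applied to the whole document.
frozen_emb = Embedding(VOCAB_SIZE, EMB_DIM,
                       weights=[pretrained],
                       trainable=False)

doc_input = Input(shape=(MAX_SENTS, MAX_WORDS), dtype='int32')
embedded = frozen_emb(doc_input)                            # (None, 20, 40, 300)
sents = TimeDistributed(Bidirectional(GRU(256)))(embedded)  # (None, 20, 512)
model = Model(doc_input, sents)
model.summary()  # the embedding now contributes 0 trainable parameters
```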

0 Answers