Keras: pruning layers after training a model with shared layers

Time: 2020-07-30 23:32:14

Tags: python tensorflow keras

I am trying to train a search encoder and an item encoder. This is the model I have:

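from tensorflow.keras.layers import Input, Embedding, GlobalMaxPool1D, Dense, Lambda, concatenate
from tensorflow.keras.models import Model

input_search = Input(shape=(40,), dtype='int64', name='input_search')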
input_title = Input(shape=(40,), dtype='int64', name='input_title')
input_desc = Input(shape=(40,), dtype='int64', name='input_desc')
input_brand = Input(shape=(40,), dtype='int64', name='input_brand')

embedding = Embedding(input_dim=20000, output_dim=50, input_length=40)
s_emb = embedding(input_search)
t_emb = embedding(input_title)
d_emb = embedding(input_desc)
b_emb = embedding(input_brand)

s = GlobalMaxPool1D()(s_emb)
t = GlobalMaxPool1D()(t_emb)
d = GlobalMaxPool1D()(d_emb)
b = GlobalMaxPool1D()(b_emb)

concat = concatenate([t, d, b])
concat = Dense(128)(concat)
s = Dense(128, name='vec')(s)
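# cos_sim is a cosine-similarity function defined elsewhere; its definition is not shown in this post
similarity = Lambda(cos_sim)([s, concat])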

model = Model(inputs=[input_search, input_desc, input_brand, input_title], outputs=similarity)
_______________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_search (InputLayer)       [(None, 40)]         0                                            
__________________________________________________________________________________________________
input_title (InputLayer)        [(None, 40)]         0                                            
__________________________________________________________________________________________________
input_desc (InputLayer)         [(None, 40)]         0                                            
__________________________________________________________________________________________________
input_brand (InputLayer)        [(None, 40)]         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 40, 50)       1526100     input_search[0][0]               
                                                                 input_title[0][0]                
                                                                 input_desc[0][0]                 
                                                                 input_brand[0][0]                
__________________________________________________________________________________________________
global_max_pooling1d_5 (GlobalM (None, 50)           0           embedding_1[1][0]                
__________________________________________________________________________________________________
global_max_pooling1d_6 (GlobalM (None, 50)           0           embedding_1[2][0]                
__________________________________________________________________________________________________
global_max_pooling1d_7 (GlobalM (None, 50)           0           embedding_1[3][0]                
__________________________________________________________________________________________________
global_max_pooling1d_4 (GlobalM (None, 50)           0           embedding_1[0][0]                
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 150)          0           global_max_pooling1d_5[0][0]     
                                                                 global_max_pooling1d_6[0][0]     
                                                                 global_max_pooling1d_7[0][0]     
__________________________________________________________________________________________________
search (Dense)                  (None, 128)          6528        global_max_pooling1d_4[0][0]     
__________________________________________________________________________________________________
product (Dense)                 (None, 128)          19328       concatenate_5[0][0]              
__________________________________________________________________________________________________
lambda (Lambda)                 (None,)              0           search[0][0]                     
                                                                 product[0][0]                    
==================================================================================================
Total params: 1,551,956
Trainable params: 1,551,956
Non-trainable params: 0
__________________________________________________________________________________________________

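An item has 3 features: title, description and brand. I want to use the same embedding layer for all 4 inputs so that the embeddings are generated in the same vector space. At prediction time, however, I want to drop the 3 item inputs and have the model output the embedding vector of the dense layer named "vec", so that I can store the feature vectors. This is what I tried: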

search_model = Model(model.inputs[0], model.layers[-2].output)
item_model = Model(inputs=model.inputs[1:], outputs=model.layers[-1].output)

But I get this error:

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_brand_1:0", shape=(None, 40), dtype=int64) at layer "input_brand". The following previous layers were accessed without issue: []
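A possible reading of this error, going by the summary above: `model.layers[-2]` is the item-side Dense layer (`product` in the summary), which depends on the title, description and brand inputs, so it cannot be computed from `model.inputs[0]` alone. Likewise, `model.layers[-1]` is the Lambda layer, which also needs the search input. Below is a minimal sketch of picking the outputs by layer name instead of by position. The layer name `item_vec` and the use of `Dot(axes=1, normalize=True)` as a stand-in for the undefined `cos_sim` are assumptions made for this illustration, not the original code.

```python
import numpy as np
from tensorflow.keras.layers import Input, Embedding, GlobalMaxPool1D, Dense, Dot, concatenate
from tensorflow.keras.models import Model

# Rebuild a compact version of the model from the question.
input_search = Input(shape=(40,), dtype='int64', name='input_search')
input_title = Input(shape=(40,), dtype='int64', name='input_title')
input_desc = Input(shape=(40,), dtype='int64', name='input_desc')
input_brand = Input(shape=(40,), dtype='int64', name='input_brand')

# A single Embedding instance is applied to all four inputs.
embedding = Embedding(input_dim=20000, output_dim=50)

s = GlobalMaxPool1D()(embedding(input_search))
s = Dense(128, name='vec')(s)

pooled_item = [GlobalMaxPool1D()(embedding(x)) for x in (input_title, input_desc, input_brand)]
item = Dense(128, name='item_vec')(concatenate(pooled_item))  # 'item_vec' is a hypothetical name

# Assumption: Dot with normalize=True stands in for the cos_sim Lambda.
similarity = Dot(axes=1, normalize=True)([s, item])

model = Model(inputs=[input_search, input_desc, input_brand, input_title], outputs=similarity)

# ... model.compile(...) and model.fit(...) would go here ...

# Each sub-model only uses inputs that its output actually depends on.
search_model = Model(inputs=model.inputs[0], outputs=model.get_layer('vec').output)
item_model = Model(inputs=model.inputs[1:], outputs=model.get_layer('item_vec').output)

# Quick check with random token ids.
q = np.random.randint(0, 20000, size=(2, 40))
search_vecs = search_model.predict(q, verbose=0)
item_vecs = item_model.predict([q, q, q], verbose=0)  # order: desc, brand, title
print(search_vecs.shape)  # (2, 128)
print(item_vecs.shape)    # (2, 128)
```

Since both sub-models are built from tensors of the same graph, they reuse the trained layer objects, so no weights need to be copied.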

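Is there a way to share the embedding layer weights between the search and item features while still being able to prune some layers at inference time? Or should I create 2 separate embedding layers, one applied to the search input and the other to the 3 item features, and somehow keep the weights of the two layers identical?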
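On the weight-sharing part of the question, one possible arrangement is to define the two encoders as separate models up front, both calling the same `Embedding` instance, and then compose them into a training model. The sketch below is only an illustration under assumptions: the names `shared_embedding`, `search_encoder`, `item_encoder`, `item_vec` and `training_model` are hypothetical, and `Dot(axes=1, normalize=True)` again stands in for `cos_sim`.

```python
import numpy as np
from tensorflow.keras.layers import Input, Embedding, GlobalMaxPool1D, Dense, Dot, concatenate
from tensorflow.keras.models import Model

SEQ_LEN, VOCAB = 40, 20000

# One Embedding instance reused by both encoders: a single weight matrix.
shared_embedding = Embedding(input_dim=VOCAB, output_dim=50, name='shared_embedding')

def build_search_encoder():
    x = Input(shape=(SEQ_LEN,), dtype='int64', name='input_search')
    h = GlobalMaxPool1D()(shared_embedding(x))
    return Model(x, Dense(128, name='vec')(h), name='search_encoder')

def build_item_encoder():
    names = ('input_title', 'input_desc', 'input_brand')
    inputs = [Input(shape=(SEQ_LEN,), dtype='int64', name=n) for n in names]
    pooled = [GlobalMaxPool1D()(shared_embedding(x)) for x in inputs]
    h = Dense(128, name='item_vec')(concatenate(pooled))
    return Model(inputs, h, name='item_encoder')

search_encoder = build_search_encoder()
item_encoder = build_item_encoder()

# Training model: cosine similarity between the two encoder outputs.
similarity = Dot(axes=1, normalize=True)([search_encoder.output, item_encoder.output])
training_model = Model(inputs=search_encoder.inputs + item_encoder.inputs, outputs=similarity)

# After training_model.fit(...), each encoder can be used on its own for inference.
q = np.random.randint(0, VOCAB, size=(3, SEQ_LEN))
search_vecs = search_encoder.predict(q, verbose=0)
print(search_vecs.shape)  # (3, 128)
```

With this layout nothing has to be pruned afterwards, and there is no second embedding layer whose weights would need to be kept in sync.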

0 answers:

No answers