加载预先训练的Keras模型,并与分支模型合并以创建多输入

时间:2019-11-05 13:06:53

标签: keras lstm

我已经使用Keras功能API预训练了顺序模型。

现在,我需要加载模型,移除顶部并将其与分支侧模型连接起来,并训练新的完整模型。

我有兴趣保留主模型的第一个预训练层。

我认为问题是要设置正确的输入和输出,以便模型正确连接和编译。我认为,当前主要模型的新输入层与主要模型的其余部分之间存在脱节。

预训练模型中的层数是动态的。这是加载模型并进行更改的动机,而不是使用config和weights重新配置完整模型中的预训练层。我还没有找到一种可行的方法,可以使用config和weights硬编码完整模型的配置,以适应不断变化的预训练层数。

更新:为了更清楚地说明问题所在,我需要设置一个连接图。错误消息为“ ValueError:图形已断开连接:无法获取层” Pretrain_input”上的张量Tensor(“ Pretrain_input:0”,shape =(?, 4,105),dtype = float32)的值。可以顺利访问以下先前的图层:[]'

##### MAIN MODEL #####

# loading pretrained model
main_model = load_model('models/main_model.h5')


# saving the output layer for reconstruction later
config_output_layer = main_model.layers[-1].get_config()
weights_output_layer = main_model.layers[-1].get_weights()


# removing the first and last layer (not sure if I need to remove the first layer, but I do this as I need an explicit 'entry point' for later concatenation with branching model later and re-compiling )
main_model.layers.pop(0)
main_model.layers.pop(-1)


# new first layer, for input later
main_input = Input(
    shape=(x_train_main.shape[1], x_train_main.shape[2]),
    name='Main_input')


# reconstructing last layer
main_output = Dense.from_config(config_last_layer)(main_model.output)


# re-defining the main model
new_main_model = Model(inputs=main_input, outputs=main_output)


 ##### BRANCHING MODEL #####   


branch_visible = Input(
    shape=(x_train_branch.shape[1], x_train_branch.shape[2]),
    name='Branch_input')


branch_hidden_0 = LSTM(
    units=units_lstm_layer,
    return_sequences=True,
    name='Branch_hidden_0'
)(branch_visible)


branch_dense = Dense(
    units=units_dense_layer,
    name='Branch_dense'
)(branch_hidden_0)


##### CONCAT THE MAIN AND THE BRANCH #####

concatenated_output = Concatenate(axis=-1)([main_output, branch_dense])


activation_layer = Dense(
    units=units_activation_layer,
    activation=activation,
    name='Activation'
)(concatenated_output)


final_model = Model(inputs=[main_visible, branch_visible], outputs=activation_layer)

0 个答案:

没有答案