训练多输出keras模型选择性分支

时间:2017-05-31 09:12:00

标签: keras

我有一个多输出Keras模型,其结构类似于:

s = some_shared_layers()(input)
non_trainable1 = Dense(trainable=False) (s) 
non_trainable2 = Dense(trainable=False) (s) 
trainable = Dense() (s) 

model = Model(input, outputs=[non_trainable1, non_trainable2, trainable])

我的模型首先计算前向传球并使用前2个输出来操纵输入。然后它计算另一个前向传递以获得第三个输出。

out1, out2,_ =model.predict(input_data) 
processed_data = foo(input_data, out1, out2) 
_,_, out3 = model.predict(processed_data)

如何调用model.fit()仅培训trainable图层?如果我排除其他输出的损失,Keras警告we will not be expecting any data to be passed to "non_trainable1" during training并将它们从计算图中排除。

是否有更好的方法来为此任务构建模型?

1 个答案:

答案 0 :(得分:0)

如果我理解正确,你根本就不需要那些图层,事实上你应该有两个模型,一个用于预测,另一个用于训练。

无法训练:

model1 = Model(input, [non_trainable1, non_trainable2])
#model 1 doesn't need to be compiled, since you won't train it    

可训练:

model2 = Model(input, trainable)
model2.compile(loss=onlyTheLossForTrainable)     

使用它们:

out1, out2 =model1.predict(input_data) 
processed_data = foo(input_data, out1, out2) 

model2.fit(processed_data, expected_outputs, ....)