我正在使用Keras训练一个涉及多个分支的稍微复杂的模型。由于问题的结构,分别训练这些片段没有意义,但是我不想将网络主分支的损失函数应用于与主分支合并的另一个分支(该分支有自己的损失函数和输出)。我想知道在Functional API中是否有任何方法可以实现这一目标。
这是一个玩具模型定义,突出了我遇到的基本问题。图层本身并不重要,重要的是模型的结构:
#Utility for subtracting some tensors
def diff(two_tensors):
x, y = two_tensors
return x - y
#Two inputs
in_1 = Input((128,))
in_2 = Input((128,))
#Main branch: Does something to the first input
branch_1 = Dense(128, activation='relu')(in_1)
branch_1 = Dense(128, activation='relu')(branch_1)
#Auxilliary classifier definition:
classifier = Sequential()
classifier.add(Dense(128, activation='relu'))
classifier.add(Dense(1, activation='linear'))
#This model asserts a confidence, and we'll use a sigmoid to actually classify
#The classifier takes the second input and the result of the main branch
pred_a = classifier(branch_1)
pred_b = classifier(in_2)
class_out = Activation('sigmoid')
class_a = class_out(pred_a)
class_b = class_out(pred_b)
#We calculate the difference between these two confidences in a lambda
pred_diff = Lambda(diff)([pred_a, pred_b])
#The main model uses this difference for another classification
branch_join = concatenate([pred_diff, branch_1], axis=-1)
main_output = Dense(20, activation='softmax')(branch_join)
#The model outputs the aux classifier's choices and the main prediction
model = Model(inputs=[in_1, in_2], outputs=[main_output, class_a, class_b])
losses = { 'main_output': 'sparse_categorical_crossentropy'
, 'class_a': 'binary_crossentropy'
, 'class_b': 'binary_crossentropy'
}
我只想在模型的中间训练分类器,因为它仅对两个输入具有分类精度。但是,由于它已连接到主分支,因此我认为主输出的损失也将通过这些层传播。在许多类似的情况下(例如简单的GAN),答案将是分别训练分类器并在训练端到端系统时冻结分类器。但是,这不适用于我的用例,我想知道是否可以阻止main_output的损失向后传播到分类器模型中,同时仍在其自身的输出上进行训练。