自定义损失函数,它依赖于keras中的另一个神经网络

时间:2018-11-27 14:30:02

标签: python keras

我对keras有一个“我该怎么做”的问题:

假设我有第一个神经网络,比如说NNa,它有4个输入(x,y,z,t),已经被训练。 如果我有第二个神经网络,例如NNb,它的损失函数取决于第一个神经网络。

NNb customLossNNb的自定义损失函数使用固定网格(x,y,z)调用NNa的预测,而只需修改最后一个变量t。

在伪python代码中,我想训练第二个NN:NNb:

grid=np.mgrid[0:10:1,0:10:1,0:10:1].reshape(3,-1).T

Y[:,0]=time
Y[:,1]=something

def customLossNNb(NNa,grid):
     def diff(y_true,y_pred): 
         for ii in range(y_true.shape[0]):
               currentInput=concatenation of grid and y_true[ii,0]
               toto[ii,:]=NNa.predict(currentInput)
               #some stuff with toto
         return #...
     return diff

然后

NNb.compile(loss=customLossNNb(NNa,K.variable(grid)),optimizer='Adam')
NNb.fit(input,Y)

实际上,引起我麻烦的那一行是currentInput=concatenation of grid and y_true[ii,0]

我尝试使用K.variable(grid)将网格作为张量发送到customLossNNb。但是我无法在损失函数中定义一个新的张量,例如CurrentY,其形状为(grid.shape[0],1),且填充y[ii,0] ie 当前t)然后将gridcurrentY连接起来以构建currentInput

有什么想法吗?

谢谢

2 个答案:

答案 0 :(得分:1)

您可以使用keras的功能API将自定义损失函数包括在图形中。在这种情况下,该模型可以用作函数,如下所示:

for l in NNa.layers: 
    l.trainable=False
x=Input(size)
y=NNb(x)
z=NNa(y)

Predict方法将不起作用,因为损失函数应该是图形的一部分,并且predict方法返回np.array

答案 1 :(得分:0)

首先,使NNa不可训练。请注意,如果您的模型具有内部模型,则应递归执行此操作。

def makeUntrainable(layer):
    layer.trainable = False

    if hasattr(layer, 'layers'):
        for l in layer.layers:
            makeUntrainable(l)

makeUntrainable(NNa)

然后您有两个选择:

  • 将NNa附加到模型的末尾(请注意,y_truey_pred都将被更改)
    • 然后更改目标(使用NNa进行预测)以获得正确的结果,因为您的模型现在期望使用NNa而不是NNb的输出。
  • 创建一个在其中使用NNa的自定义损失函数,而无需更改目标

选项1-附加模型

inputs = NNb.inputs   
outputs = NNa(NNb.outputs) #make sure NNb is outputing 4 tensors to match NNa inputs   
fullModel = Model(inputs,outputs)

#changing the targets:
newY_train = NNa.predict(oldY_train)    

选项2-创建自定义损失

  

警告:请在训练此配置时测试NNa的重量是否真的冻结

from keras.losses import binary_crossentropy

def customLoss(true,pred):
    true = NNa(true)
    pred = NNa(pred)

    #use some of the usual losses or create your own
    binary_crossentropy(true,pred)

NNb.compile(optimizer=anything, loss = customLoss)