Keras custom loss function for a multi-output model?

Posted: 2021-01-07 12:26:30

Tags: python tensorflow machine-learning keras deep-learning

(You can find minimal code and a Colab link in the section below.)
I have a Keras model defined as:

from tensorflow import keras
from tensorflow.keras.layers import Input

model_in = Input((None, None, 3))
atts_out, clss_out, regs_out = [], [], []
for i in range(5):
  ...  # Do what I want, and produce att_out, cls_out, reg_out Tensors
  atts_out.append(att_out)
  clss_out.append(cls_out)
  regs_out.append(reg_out)
model = keras.Model(model_in, [atts_out, clss_out, regs_out])
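In tf.keras 2.x, a Functional model flattens a nested output structure such as [atts_out, clss_out, regs_out] with tf.nest.flatten, so the model ends up with one flat list of output tensors. The sketch below uses hypothetical string placeholders instead of real tensors (an assumption made only to show the ordering) and computes which flat positions the regs_out tensors occupy; a list of per-output losses passed to compile() has to follow this same order.

```python
import tensorflow as tf

N_LEVELS = 5  # the loop in the question runs range(5)

# Hypothetical string placeholders standing in for the real output tensors.
atts_out = [f"att_{i}" for i in range(N_LEVELS)]
clss_out = [f"cls_{i}" for i in range(N_LEVELS)]
regs_out = [f"reg_{i}" for i in range(N_LEVELS)]

# Flatten the nested structure depth-first, as tf.keras does for model outputs.
flat_outputs = tf.nest.flatten([atts_out, clss_out, regs_out])

# Positions of the regs_out tensors in the flat output list.
reg_positions = [i for i, name in enumerate(flat_outputs) if name.startswith("reg_")]

print(flat_outputs[:3])  # ['att_0', 'att_1', 'att_2']
print(reg_positions)     # [10, 11, 12, 13, 14]
```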

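Now I want to use a custom loss function in model.compile(), but I am confused about the shapes of its arguments (y_true and y_pred). Suppose I want to apply MSE to the regs_out outputs of the model. How should I complete the code below?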

import tensorflow as tf

def hybrid_loss(y_true, y_pred):
  true_regs_out = ...  # How to extract regs_out from y_true?
  pred_regs_out = ...  # How to extract regs_out from y_pred?

  regs_loss = 0
  for true_reg_out, pred_reg_out in zip(true_regs_out, pred_regs_out):
    # What are the shapes of true_reg_out and pred_reg_out? Are they (batch_size, shape of reg_out)?
    regs_loss += tf.reduce_mean((true_reg_out - pred_reg_out)**2)  # MSE for one regression output

  return regs_loss
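In tf.keras 2.x, a loss function passed to compile() is not handed the whole [atts_out, clss_out, regs_out] structure: it is called once per output tensor, with y_true and y_pred holding only that output for the current batch. A minimal sketch under that assumption, using Dense layers as hypothetical stand-ins for the real attention/class/regression heads, N_LEVELS = 2 instead of 5 to keep it small, and a flat output list. Every output gets a loss, but loss_weights of 0.0 leave only the regs_out outputs contributing to the total loss; reg_mse is a hypothetical helper name.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense

N_LEVELS = 2  # the question uses 5; 2 keeps the sketch small


def reg_mse(y_true, y_pred):
    # Receives ONE reg_out target/prediction pair, shape (batch_size, units).
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


model_in = Input((3,))
# Hypothetical stand-in heads; the real model would build these in its loop.
atts_out = [Dense(4)(model_in) for _ in range(N_LEVELS)]
clss_out = [Dense(5)(model_in) for _ in range(N_LEVELS)]
regs_out = [Dense(6)(model_in) for _ in range(N_LEVELS)]
outputs = atts_out + clss_out + regs_out  # flat list of 3 * N_LEVELS outputs
model = keras.Model(model_in, outputs)

# One loss per flat output; weight 0.0 switches off the att/cls outputs.
losses = ["mse"] * (2 * N_LEVELS) + [reg_mse] * N_LEVELS
weights = [0.0] * (2 * N_LEVELS) + [1.0] * N_LEVELS
model.compile("adam", loss=losses, loss_weights=weights)

x = np.random.rand(8, 3).astype("float32")
ys = [np.random.rand(8, int(o.shape[-1])).astype("float32") for o in outputs]
history = model.fit(x, ys, batch_size=4, epochs=1, verbose=0)
```

Zero weights keep the att/cls losses computed (and logged) but exclude them from the gradient; giving those outputs their own losses later only requires changing the two lists.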

Complete minimal example

Here is another, complete example:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense


def f1(a, b):
  return tf.reduce_sum((a-b)**1)

def f2(a, b):
  return tf.reduce_sum((a-b)**2)

def f3(a, b):
  return tf.reduce_sum((a-b)**3)

def hybrid_loss(y_true, y_pred):
  # I want to apply f1 on x1, f2 on x2, f3 on x3
  x1_true = ...   # how to extract x1_true
  x1_pred = ...   # how to extract x1_pred
  x2_true = ...   # how to extract x2_true
  x2_pred = ...   # how to extract x2_pred
  x3_true = ...   # how to extract x3_true
  x3_pred = ...   # how to extract x3_pred
  return f1(x1_true, x1_pred) + f2(x2_true, x2_pred) + f3(x3_true, x3_pred)


model_in = Input((3, ))
x1 = Dense(10)(model_in)
x2 = Dense(20)(model_in)
x3 = Dense(30)(model_in)
model = keras.Model(model_in, [[x1, x2], x3])

model.compile('adam', hybrid_loss)

t0 = np.random.rand(5, 3)
t1 = np.random.rand(5, 10) + 10
t2 = np.random.rand(5, 20) + 20
t3 = np.random.rand(5, 30) + 30

model.fit(t0, [[t1, t2], t3], batch_size=2)
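One common way to get "f1 on x1, f2 on x2, f3 on x3" without extracting anything inside a single loss is to pass one loss per output to compile(); Keras then calls each function with only that output's y_true and y_pred and adds the results. A sketch assuming a flat output list [x1, x2, x3] in place of the nested [[x1, x2], x3], with the targets cast to float32.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense


def f1(a, b):
    return tf.reduce_sum((a - b) ** 1)


def f2(a, b):
    return tf.reduce_sum((a - b) ** 2)


def f3(a, b):
    return tf.reduce_sum((a - b) ** 3)


model_in = Input((3,))
x1 = Dense(10)(model_in)
x2 = Dense(20)(model_in)
x3 = Dense(30)(model_in)
# Flat output list, so the list of losses lines up one-to-one with the outputs.
model = keras.Model(model_in, [x1, x2, x3])

# Keras calls f1(t1_batch, x1_batch), f2(t2_batch, x2_batch), f3(t3_batch, x3_batch)
# and sums the three values (default loss weights of 1.0) into the total loss.
model.compile("adam", loss=[f1, f2, f3])

t0 = np.random.rand(5, 3).astype("float32")
t1 = (np.random.rand(5, 10) + 10).astype("float32")
t2 = (np.random.rand(5, 20) + 20).astype("float32")
t3 = (np.random.rand(5, 30) + 30).astype("float32")

history = model.fit(t0, [t1, t2, t3], batch_size=2, epochs=1, verbose=0)
```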

Link to work on the code in Colab
Filling in hybrid_loss above would solve my problem, but I would also appreciate details about the types and shapes of y_true and y_pred.
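To see the types and shapes directly, a loss can record what it receives. A sketch assuming tf.keras with run_eagerly=True, so the Python side effect runs on every batch with concrete shapes; shape_logging_mse and seen_shapes are hypothetical names. Each call gets tf.Tensor objects for a single output, shaped (batch_size, units); with 5 samples and batch_size=2 the batches hold 2, 2 and 1 samples. Without run_eagerly the function is traced instead, and the batch dimension typically shows up as None.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense

seen_shapes = []  # (y_true shape, y_pred shape) for every call


def shape_logging_mse(y_true, y_pred):
    # Record the shapes Keras passes in, then compute a per-sample MSE.
    seen_shapes.append((tuple(y_true.shape), tuple(y_pred.shape)))
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


model_in = Input((3,))
x1 = Dense(10)(model_in)
x2 = Dense(20)(model_in)
x3 = Dense(30)(model_in)
model = keras.Model(model_in, [x1, x2, x3])

# Same loss for every output; run eagerly so each batch records concrete shapes.
model.compile("adam", loss=[shape_logging_mse] * 3, run_eagerly=True)

t0 = np.random.rand(5, 3).astype("float32")
t1 = np.random.rand(5, 10).astype("float32")
t2 = np.random.rand(5, 20).astype("float32")
t3 = np.random.rand(5, 30).astype("float32")
model.fit(t0, [t1, t2, t3], batch_size=2, epochs=1, verbose=0)

for true_shape, pred_shape in seen_shapes:
    print("y_true:", true_shape, "y_pred:", pred_shape)
```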

0 answers