如何在tensorflow.keras模型中正确使用自定义损失(例如骰子系数)?

时间:2019-08-27 16:42:26

标签: python tensorflow keras loss

我试图在tensorflow.keras模型编译中输入自定义损失或度量函数,但出现错误。看起来输入没有正确传递。

我能够在以下张量流损失下运行模型编译/拟合:

loss = tf.keras.losses.MeanSquaredError() # Breaks if I remove ()
loss = tf.nn.sigmoid_cross_entropy_with_logits # Breaks if I add ()

我对语法上的差异感到困惑,这些语法使它们无法正常工作(也许是错误的?),如果我使用自定义骰子丢失,那么这两种语法都不起作用。

当我在下面运行自定义骰子丢失时,输入标签正确传递为batch_size*height*width,但输入logits正确传递为None,None,None,None(看起来不正确?),骰子丢失功能错误。我正在尝试进行批次优化,因此损失应该由model.fit计算每个批次。

def generalized_dice(labels, logits):
    smooth = 1e-17
    shape = tf.TensorShape(logits.shape).as_list()
    depth = int(shape[-1])
    labels = tf.one_hot(labels, depth, dtype=tf.float32)
    logits = tf.nn.softmax(logits)
    weights = 1.0 / (tf.reduce_sum(labels, axis=[0, 1, 2])**2)

    numerator = tf.reduce_sum(labels * logits, axis=[0, 1, 2])
    numerator = tf.reduce_sum(weights * numerator)

    denominator = tf.reduce_sum(labels + logits, axis=[0, 1, 2])
    denominator = tf.reduce_sum(weights * denominator)

    loss = 2.0*(numerator + smooth)/(denominator + smooth)
    return loss


def generalized_dice_loss(dice):
    return 1-dice 


model = tf.keras.Model(inputs=[input_x], outputs=[predictions])

loss = tf.keras.losses.MeanSquaredError() # Breaks if I remove (), with inputs being passed as None's 
# loss = tf.nn.sigmoid_cross_entropy_with_logits 
metric = generalized_dice # If no parantehses, labels passed as None's; 
# Both tf.nn.sigmoid_cross_entropy_with_logits and generalized_dice don't work if I add () saying inputs and labels must be provided; since I am training in batches, these need to be provided by the model during training, not when I am defining loss?  

model.compile(optimizer=tf.keras.optimizers.Adam(),  
              loss=loss, 
              metrics=[metric]) 

由于将骰子损失的输入作为None传递,所以

labels = tf.one_hot(labels, depth, dtype=tf.float32) 

结果:

TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64

0 个答案:

没有答案