TF2.0中Keras损失中的sample_weight参数的奇怪形状要求

时间:2019-08-18 03:58:03

标签: tensorflow keras

根据TF文档,sample_weight参数的形状可以为[batch_size]。相关文档引用如下:

  

sample_weight:可选Tensor,其等级为0,或与y_true相同,或者可以广播到y_truesample_weight充当损耗的系数。如果提供了标量,则损耗将简单地按给定值缩放。如果sample_weight是大小[batch_size]的张量,则该批次的每个样本的总损失将由sample_weight向量中的相应元素重新缩放。如果sample_weight的形状与y_pred的形状匹配,则y_pred的每个可测量元素的损失将按sample_weight的相应值缩放。

但是,我不明白为什么以下代码不起作用。

import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss)

代码抛出此错误:

  

tensorflow.python.framework.errors_impl.InvalidArgumentError:无法挤压dim [0],预期尺寸为1,得到5 [Op:Squeeze]

如果我扩展gtpredsample_weights的尺寸,则它将正常工作并输出期望的损耗值3.0849898。

import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])

# expand dims
gt = tf.expand_dims(gt, 1)
pred = tf.expand_dims(pred, 1)
sample_weights = tf.expand_dims(sample_weights, 1)

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss)  # loss is 3.0849898

1 个答案:

答案 0 :(得分:0)

问题不在于sample_weight的形状。 predgt的形状应为[batch_size, n_labels]

import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])

# expand dims
gt = tf.expand_dims(gt, 1)
pred = tf.expand_dims(pred, 1)
print(gt.shape, pred.shape) #(5, 1) (5, 1)

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss)  # loss is 3.0849898