使用tf.data API和样本权重进行培训

时间:2018-12-06 14:23:03

标签: python tensorflow tensorflow-datasets

我所有的训练图像都在tfrecords文件中。现在,它们以这样的标准方式使用:

dataset = dataset.apply(tf.data.experimental.map_and_batch(
            map_func=lambda x: preprocess(x, data_augmentation_options=data_augmentation), 
            batch_size=images_per_batch)

其中预处理返回来自tfrecord文件的解码图像和标签。

现在是新情况。我还需要每个示例的样本权重。因此,而不是

return image,label

在预处理中应该是

return image, label, sample_weight

但是,此sample_weight不在tfrecord文件中。它是根据每个课程的示例数量在训练开始时计算的。基本上,这是Python字典的weights [label] = sample_weights。

问题是如何在tf.data管道中使用这些样本权重。由于label是张量,因此不能用于索引Python字典。

1 个答案:

答案 0 :(得分:0)

有些问题在您的问题上不清楚,因为x是什么?如果您可以在问题中张贴整个代码示例,那就更好了。

我假设x与图像和标签一样张量。如果是这样,您可以使用地图功能将样本权重的张量添加到数据集中。如下(请注意,此代码未经测试):

def im_add_weight(image, label, sample_weight):
   #convert to tensor if they are not and make sure to us
   image= tf.convert_to_tensor(image, dtype= tf.float32)
   label = tf.convert_to_tensor(label, dtype= tf.float32)
   sample_weight = tf.convert_to_tensor(sample_weight, dtype= tf.float32)
   return image, label, sample_weight

dataset = dataset .map(
lambda image, label, sample_weight: tuple(tf.py_func(
    im_add_weight, [image, label,sample_weight], [tf.float32, tf.float32,tf.float32])))