我所有的训练图像都在tfrecords文件中。现在,它们以这样的标准方式使用:
dataset = dataset.apply(tf.data.experimental.map_and_batch(
map_func=lambda x: preprocess(x, data_augmentation_options=data_augmentation),
batch_size=images_per_batch)
其中预处理返回来自tfrecord文件的解码图像和标签。
现在是新情况。我还需要每个示例的样本权重。因此,而不是
return image,label
在预处理中应该是
return image, label, sample_weight
但是,此sample_weight不在tfrecord文件中。它是根据每个课程的示例数量在训练开始时计算的。基本上,这是Python字典的weights [label] = sample_weights。
问题是如何在tf.data管道中使用这些样本权重。由于label是张量,因此不能用于索引Python字典。
答案 0 :(得分:0)
有些问题在您的问题上不清楚,因为x是什么?如果您可以在问题中张贴整个代码示例,那就更好了。
我假设x与图像和标签一样张量。如果是这样,您可以使用地图功能将样本权重的张量添加到数据集中。如下(请注意,此代码未经测试):
def im_add_weight(image, label, sample_weight):
#convert to tensor if they are not and make sure to us
image= tf.convert_to_tensor(image, dtype= tf.float32)
label = tf.convert_to_tensor(label, dtype= tf.float32)
sample_weight = tf.convert_to_tensor(sample_weight, dtype= tf.float32)
return image, label, sample_weight
dataset = dataset .map(
lambda image, label, sample_weight: tuple(tf.py_func(
im_add_weight, [image, label,sample_weight], [tf.float32, tf.float32,tf.float32])))