用于张量流对象检测的平衡数据集

时间:2018-06-02 16:10:25

标签: tensorflow dataset object-detection

我目前想要使用Tensorflows Object Detection API来解决我的自定义问题。 我已经创建了数据集,但它非常不平衡。 数据集有3个类,我的主要问题是,一个类有大约16k个样本,另一个类只有大约2.5k个样本。

所以我认为我必须平衡数据集。有人告诉我,有一种称为样本/类权重(不确定这是否100%正确),这样可以平衡训练样本,这样最大的类对训练的影响小于最小的类。

我无法找到这种平衡方法。有人请求给我一个暗示从哪里开始?

谢谢!

1 个答案:

答案 0 :(得分:0)

你可以做正常的交叉熵,给你一个? x 1张量,X的损失

如果您希望班级编号N计数T倍,您可以

X = X * tf.reduce_sum(tf.multiply(one_hot_label, class_weight), axis = 1)

tf.multiply

根据您想要的重量来缩放标签,

tf.reduce_sum

将标签向量a转换为标量,所以最终会得到一个? x 1张量充满了类权重。然后,您只需将损失张量乘以权重张量即可得到预期的结果。

由于一个类比另一个类多6.4倍,我将分别将权重1和6.4应用于更常见和不太常见的类。这意味着每次发生不太常见的类时,它的效果是更常见的类的6.4倍,所以就像从每个类中看到相同数量的样本一样。

您可能希望对其进行修改,以便加权累加到类的数量。这符合默认情况,所有权重均为1.在这种情况下,我们有1 /7.4和6.4 / 7.4