如何在tensorflow中为非分类对象创建一个类?

时间:2017-10-30 10:33:08

标签: tensorflow deep-learning classification object-detection

您好我已经建立了我的CNN,有两个类别的狗和猫,我已经训练了这个,现在我能够对狗和猫图像进行分类。但是,如果我想为新的未分类对象引入一个类呢?例如,如果我用网络图像为我的网络提供网络,则网络会给我一个错误的分类。我想用新的非分类对象的第三类构建我的网络。但是我怎样才能建立这个第三类。我必须使用哪些图像来获得与狗或猫不同的新物体的课程? 实际上在我的网络结束时,我使用Softmax,我的代码是使用tensorflow开发的。有人能给我一些建议吗?谢谢

2 个答案:

答案 0 :(得分:4)

您需要在网络中添加第三个“其他”类。有几种方法可以解决它。一般来说,如果您有一个想要检测的类,那么您应该有该类的示例,这样您就可以将没有猫或狗的图像添加到标有新类的训练数据中。然而,这有点棘手,因为根据定义,新类是宇宙中的所有东西,但是狗和猫,所以你不可能期望有足够的数据来训练它。但实际上,如果你有足够的例子,网络可能会知道只要前两个没有触发第三个类。

我过去使用的另一个选项是对“默认”类进行建模,使其与常规类略有不同。因此,您可以只是明确地说它不是激活猫或狗神经元的任何东西,而不是试图真正了解什么是“非猫或狗”图像。我通过将最后一层从softmax替换为sigmoids来做到这一点(因此损失将是sigmoid交叉熵而不是softmax交叉熵,并且输出将不再是分类概率分布,但老实说它没有做多少在我的情况下性能差异),然后将“默认”类表示为1减去每个其他类的最大激活值。因此,如果没有一个类激活0.5更大(即估计该类的概率为50%),则“默认”类将是得分最高的类。您可以探索其他类似的方案。

答案 1 :(得分:1)

您应该只向数据集添加既不是狗也不是猫的图像,将它们标记为import matplotlib.pyplot as plt samples = [sample(5, 2, 1.0, 2.0) for _ in range(10000)] xs, ys = zip(*samples) plt.scatter(xs, ys, s=0.1) plt.axis("equal") plt.show() ,并将"Other"视为所有代码中的普通类。特别是你将获得超过3个类的softmax。

您正在使用的图像可以是任何图像(当然除了猫和狗),但应该与您在使用网络时可能要测试的图像相同。因此,例如,如果您知道您将测试狗,猫和其他动物的图像,则与其他动物一起训练,而不是用鲜花图片。如果你不知道你将要测试什么,尝试从不同的来源获得非常多样化的图像,以便网络很好地了解这个类是“除了猫和狗之外的任何东西”(广泛的图像在现实世界中属于这一类别的应该反映在您的训练数据集中。)