如何在TensorFlow对象检测API中

时间:2019-04-12 12:19:17

标签: python tensorflow object-detection object-detection-api transfer-learning

我目前正在使用TensorFlow Object Detection API,并试图从模型动物园微调经过预先训练的Faster-RCNN。当前,如果我选择与原始网络中使用的数字不同的类数,则它不会简单地初始化来自SecondStageBoxPredictor/ClassPredictor的权重和偏差,因为它现在具有与原始ClassPredictor不同的维度。但是,由于我要训练网络的所有类都是原始网络已经训练确定的类,因此我想保留与我要在SecondStageBoxPredictor/ClassPredictor和修剪所有其他值,而不是简单地从头开始初始化这些值(类似于this function的行为)。

这是否可能,如果可以,我将如何在Estimator中修改该层的结构?

n.b。 This question提出了类似的要求,他们的回应是忽略网络输出中不相关的类-但是,在这种情况下,我试图微调网络,并且我认为这些冗余类的存在会使训练变得复杂/评估过程?

1 个答案:

答案 0 :(得分:1)

如果您想训练网络的所有课程都是经过训练的网络识别,您可以简单地使用网络进行检测,不是吗?

但是,如果您有更多的类,并且想要进行转移学习,则可以通过设置以下内容从检查点还原尽可能多的变量:

fine_tune_checkpoint_type: 'detection'
load_all_detection_checkpoint_vars: True

位于管道配置文件中的字段train_config中。

最后,通过查看计算图,可以看出SecondStageBoxPredictor/ClassPredictor/weights的形状取决于输出类的数量。 enter image description here

请注意,在tensorflow中,您只能在变量级别恢复,如果两个变量的形状不同,则一个不能使用一个来初始化另一个。因此,在您的情况下,保留weights变量的某些值的想法是不可行的。