如何在Tensorflow Object Detection API中初始化卷积层的权重?

时间:2019-03-15 09:01:25

标签: tensorflow object-detection-api

我遵循此tutorial来实现Tensorflow对象检测API。

首选方式是使用预先训练的模型。

但是在某些情况下,我们需要从头开始培训。

为此,我们只需要将配置文件中的两行注释为

#fine_tune_checkpoint: "object_detection/data/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt"
#from_detection_checkpoint: true 

如果我想使用Xavier重量初始化来初始化重量,该怎么办?

2 个答案:

答案 0 :(得分:1)

configuration protobuf definition中可以看到,可以使用3种初始化器:

  • TruncatedNormalInitializer truncated_normal_initializer
  • VarianceScalingInitializervariance_scaling_initializer
  • RandomNormalInitializer random_normal_initializer

您正在寻找VarianceScalingInitializer。它是通用的初始化程序,您可以通过设置factor=1.0, mode='FAN_AVG'基本上将其转换为Xavier初始化程序,如the documentation中所述。

因此,通过将初始化程序设置为

initializer {
    variance_scaling_initializer {
        factor: 1.0
        uniform: true
        mode: FAN_AVG
    }
}

在您的配置中,您将获得Xavier初始化程序。

但是,即使您需要训练新数据,也可以考虑使用预先训练的网络作为初始化,而不是随机初始化。有关更多详细信息,请参见this article

答案 1 :(得分:0)

mobilenet_v1功能提取器从research / slim / nets导入骨干网:

25:   from nets import mobilenet_v1

移动网络的代码根据specification实例化层,如下所示:

net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel, stride=conv_def.stride, scope=end_point)

请参阅 https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py#L264

如您所见,没有kwargs传递给conv2d调用,因此对于当前代码,您无法指定将使用哪个weights_initializer

但是,by default初始化程序无论如何都是Xavier,所以您很幸运。

我必须说,在某些辅助任务上未对特征提取器进行预训练的训练和目标检测模型可能会失败。