将特征提取器导入Tensorflow对象检测API

时间:2019-09-24 16:26:34

标签: tensorflow keras object-detection object-detection-api resnet

我正在尝试导入基于 ResNet 的体系结构作为 SSD 的功能提取器。 由于我的应用程序需要实时,因此我尝试使用不同的ResNet深度(resnet-10,resnet-16,resnet-24等)。

我已经阅读了Defining a new Faster R-CNN or SSD Feature Extractor教程,并试图通过object_detection/models/ssd_resnet_v1_ppn_feature_extractor(以及fpn)指导自己。

问题是,我正在从tensorflowcv python包中获取resnet-10(例如)体系结构,该体系结构是通过Keras实现的,并且我将函数resnet10_base定义为重定义get_resnet的内容,如下所示,因此它忽略了使其成为分类器的最后一层。

我对get_resnet的重新定义:

 01  def resnet10_base(model_name='resnet10', root=os.path.join("~", ".tensorflow", "    models")):
 02   layers = [1,1,1,1]
 03   assert (sum(layers) * 2 + 2 == blocks)
 04 
 05   init_block_channels = 64
 06   channels_per_layers = [64, 128, 256, 512]
 07 
 08   channels = [[ci] * li for (ci, li) in zip(channels_per_layers, layers)]
 09 
 10   if pretrained:
 11     if (model_name is None) or (not model_name):
 12       raise ValueError("Parameter `model_name` should be properly initialized fo    r loading pretrained model.")
 13     from .model_store import download_state_dict
 14     net.state_dict, net.file_path = download_state_dict(
 15       model_name=model_name,
 16       local_model_store_dir_path=root)
 17 
 18   return net

我不明白如何创建extract_features()要求的SSDKerasFeatureExtractor函数。例如,是否必须创建TFSession并在其中运行网络?还是我必须返回张量?我该如何实现?

在此先感谢您带来的不便。

0 个答案:

没有答案