TF异物检测Zoo模型没有可训练的变量吗?

时间:2019-06-11 15:42:14

标签: python tensorflow object-detection object-detection-api pre-trained-model

TF Objection Detection Zoo中的模型具有meta + ckpt文件,Frozen.pb文件和Saved_model文件。

我试图使用meta + ckpt文件进行进一步的训练,并为研究目的提取特定张量的权重。我看到这些模型没有任何可训练的变量。

vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

上面的代码段给出了[]列表。我也尝试使用以下内容。

vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)

我再次得到一个[]列表。

这怎么可能?模型是否去除了变量?还是tf.Variable(trainable=False)?在哪里可以获取具有有效可训练变量的meta + ckpt文件。我专门研究SSD +移动网络模型

更新:

以下是我用于还原的代码段。由于正在为某些应用程序创建自定义工具,因此该代码段位于类中。

def _importer(self):
    sess = tf.InteractiveSession()
    with sess.as_default():
        reader = tf.train.import_meta_graph(self.metafile,
                                            clear_devices=True)
        reader.restore(sess, self.ckptfile)

def _read_graph(self):
    sess = tf.get_default_session()
    with sess.as_default():
        vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print(vars)

更新2:

我还尝试了以下代码段。简单的还原风格。

model_dir = 'ssd_mobilenet_v2/'

meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()

sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        reader = tf.train.import_meta_graph(meta,clear_devices=True)
        reader.restore(sess,ckpt)

        vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        for var in vari:
            print(var.name,"\n")

以上代码段还提供了[]变量列表

1 个答案:

答案 0 :(得分:2)

经过一些研究,对您问题的最终答案是。很明显,直到您意识到variables中的saved_model目录为空。

对象检测模型zoo提供的检查点文件包含以下文件:

.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
    |-- saved_model.pb
    `-- variables

pipeline.config是保存的模型的配置文件,frozen_inference_graph.pb用于现成的推断。请注意,checkpointmodel.ckpt.data-00000-of-00001model.ckpt.metamodel.ckpt.index 都对应于检查点。 (Here,您会找到一个很好的解释)

因此,当您想获取可训练的变量时,唯一有用的是saved_model目录。

  

使用SavedModel保存和加载您的模型-变量,图形和图形的元数据。这是一种与语言无关的,可恢复的,密封的序列化格式,使更高级别的系统和工具能够生成,使用和转换TensorFlow模型。

要恢复SavedModel,可以使用api tf.saved_model.loader.load(),此api包含一个称为tags的参数,用于指定MetaGraphDef的类型。因此,如果要获取可训练的变量,则在调用api时需要指定tag_constants.TRAINING

我试图调用此api来恢复变量,但它给了我说

的错误
  

与标签“ train”相关的MetaGraphDef在SavedModel中找不到。要检查SavedModel中可用的标签集,请使用SavedModel CLI:saved_model_cli

因此,我执行了此saved_model_cli命令来检查SavedModel中所有可用的标签。

#from directory saved_model
saved_model_cli show --dir . --all

,输出为

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
  ...

因此,train中没有标签serve,只有SavedModel。因此,此处的SavedModel仅用于张量流服务。这意味着当创建这些文件时,如果未使用标签training指定这些文件,则无法从这些文件中恢复训练变量。

P.S .:以下代码是我用于还原SavedModel的代码。设置tag_constants.TRAINING时,无法完成加载,但是设置tag_constants.SERVING时,加载成功,但变量为空。

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
  variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  print(variables)

P.P.S:我找到了用于创建SavedModel here的脚本。可以看出,创建train时确实没有SavedModel标签。