TF Objection Detection Zoo中的模型具有meta + ckpt文件,Frozen.pb文件和Saved_model文件。
我试图使用meta + ckpt文件进行进一步的训练,并为研究目的提取特定张量的权重。我看到这些模型没有任何可训练的变量。
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
上面的代码段给出了[]
列表。我也尝试使用以下内容。
vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)
我再次得到一个[]
列表。
这怎么可能?模型是否去除了变量?还是tf.Variable(trainable=False)
?在哪里可以获取具有有效可训练变量的meta + ckpt文件。我专门研究SSD +移动网络模型
更新:
以下是我用于还原的代码段。由于正在为某些应用程序创建自定义工具,因此该代码段位于类中。
def _importer(self):
sess = tf.InteractiveSession()
with sess.as_default():
reader = tf.train.import_meta_graph(self.metafile,
clear_devices=True)
reader.restore(sess, self.ckptfile)
def _read_graph(self):
sess = tf.get_default_session()
with sess.as_default():
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)
更新2:
我还尝试了以下代码段。简单的还原风格。
model_dir = 'ssd_mobilenet_v2/'
meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()
sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
with tf.Session() as sess:
reader = tf.train.import_meta_graph(meta,clear_devices=True)
reader.restore(sess,ckpt)
vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in vari:
print(var.name,"\n")
以上代码段还提供了[]
变量列表
答案 0 :(得分:2)
经过一些研究,对您问题的最终答案是是。很明显,直到您意识到variables
中的saved_model
目录为空。
对象检测模型zoo提供的检查点文件包含以下文件:
.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
|-- saved_model.pb
`-- variables
pipeline.config
是保存的模型的配置文件,frozen_inference_graph.pb
用于现成的推断。请注意,checkpoint
,model.ckpt.data-00000-of-00001
,model.ckpt.meta
和model.ckpt.index
都对应于检查点。 (Here,您会找到一个很好的解释)
因此,当您想获取可训练的变量时,唯一有用的是saved_model
目录。
使用SavedModel保存和加载您的模型-变量,图形和图形的元数据。这是一种与语言无关的,可恢复的,密封的序列化格式,使更高级别的系统和工具能够生成,使用和转换TensorFlow模型。
要恢复SavedModel
,可以使用api tf.saved_model.loader.load()
,此api包含一个称为tags
的参数,用于指定MetaGraphDef
的类型。因此,如果要获取可训练的变量,则在调用api时需要指定tag_constants.TRAINING
。
我试图调用此api来恢复变量,但它给了我说
的错误与标签“ train”相关的MetaGraphDef在SavedModel中找不到。要检查SavedModel中可用的标签集,请使用SavedModel CLI:
saved_model_cli
因此,我执行了此saved_model_cli
命令来检查SavedModel
中所有可用的标签。
#from directory saved_model
saved_model_cli show --dir . --all
,输出为
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
...
因此,train
中没有标签serve
,只有SavedModel
。因此,此处的SavedModel
仅用于张量流服务。这意味着当创建这些文件时,如果未使用标签training
指定这些文件,则无法从这些文件中恢复训练变量。
P.S .:以下代码是我用于还原SavedModel
的代码。设置tag_constants.TRAINING
时,无法完成加载,但是设置tag_constants.SERVING
时,加载成功,但变量为空。
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(variables)
P.P.S:我找到了用于创建SavedModel
here的脚本。可以看出,创建train
时确实没有SavedModel
标签。