我正在TF中做我的第一步,我觉得我很困惑......
我正在尝试微调(重新训练)我从object detection zoo下载的预训练模型。
更具体地说,我想仅训练ssd_mobilenet_v2_coco的BoxPredictor
层:BoxEncodingPredictor
(权重+偏差)和ClassPredictor
(权重+偏差)。
根据我的理解,我需要执行以下步骤:
var_list = list(filter(box_predictor_filter, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
loss = tf.get_collection(tf.GraphKeys.LOSSES)
optim = tf.train.AdamOptimizer().minimize(loss=loss, var_list=var_list)
问题是我无法从图表中获取var_list
和loss
。按照我的预期方式:
我正在导入" ssd_mobilenet_v2_coco"使用tf.train.import_meta_graph
的图表,如下所示:
import os
import tensorflow as tf
ckpt_path = os.path.join(MODELS_PATH,"ssd_mobilenet_v2_coco_2018_03_29")
def import_graph_from_ckpt():
saver = tf.train.import_meta_graph(os.path.join(ckpt_path, 'model.ckpt.meta'))
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph = tf.get_default_graph()
return graph
但所有变量列表都是空的:
my_graph = import_graph_from_ckpt()
print("------------ *** TRAINABLE_VARIABLES *** ------------")
print(my_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
# output: []
我也尝试过不同类型的变量,所有变量都是空的。
但get_operations()
返回变量:
print("------------ *** get_operations VariableV2 *** ------------")
graph_vars = list(filter(lambda op: op.type == "VariableV2", my_graph.get_operations()))
print(len(graph_vars))
# output: 345
这怎么可能?如何使用可训练变量恢复图形?