在TensorFlow中重新训练(微调)预训练模型的特定层(保存为ckpt)

时间:2018-04-17 12:29:43

标签: tensorflow object-detection

我正在TF中做我的第一步,我觉得我很困惑......

我正在尝试微调(重新训练)我从object detection zoo下载的预训练模型。 更具体地说,我想仅训练ssd_mobilenet_v2_cocoBoxPredictor层:BoxEncodingPredictor(权重+偏差)和ClassPredictor(权重+偏差)。

根据我的理解,我需要执行以下步骤:

var_list = list(filter(box_predictor_filter, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
loss = tf.get_collection(tf.GraphKeys.LOSSES)
optim = tf.train.AdamOptimizer().minimize(loss=loss, var_list=var_list)

问题是我无法从图表中获取var_listloss。按照我的预期方式:

我正在导入" ssd_mobilenet_v2_coco"使用tf.train.import_meta_graph的图表,如下所示:

import os
import tensorflow as tf

ckpt_path = os.path.join(MODELS_PATH,"ssd_mobilenet_v2_coco_2018_03_29")

def import_graph_from_ckpt():
    saver = tf.train.import_meta_graph(os.path.join(ckpt_path, 'model.ckpt.meta'))
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        graph = tf.get_default_graph()
    return graph

但所有变量列表都是空的:

my_graph = import_graph_from_ckpt()
print("------------ *** TRAINABLE_VARIABLES *** ------------")
print(my_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
# output: []

我也尝试过不同类型的变量,所有变量都是空的。

get_operations()返回变量:

print("------------ *** get_operations VariableV2 *** ------------")
graph_vars = list(filter(lambda op: op.type == "VariableV2", my_graph.get_operations()))
print(len(graph_vars))
# output: 345

这怎么可能?如何使用可训练变量恢复图形?

0 个答案:

没有答案