如何使用tf.estimator中的预训练模型进行微调

时间:2018-07-16 02:41:19

标签: tensorflow tensorflow-estimator

我使用MMDNN工具从caffe转换了模型,将caffe模型转换为save_model张量流样式。这是一个resnet18模型,我在最后只删除了几层,希望我可以在tf.estimator的model_fn中加载此体系结构,并手动添加一些额外的层来完成我的工作。 如本教程所建议,我可以使用loader.load方法来加载save_model。但是我只想在估计器中使用它,而我需要在model_fn函数中定义架构。我搜索了SO和github,但是没有一个非常具体的工作流程来执行该操作,有人可以帮我吗?

1 个答案:

答案 0 :(得分:1)

这是使用tf.Estimator进行微调的一种方法:

  1. 使用与保存的模型相同的变量名/范围来定义模型
  2. 使用tf.estimator的热启动功能使用保存的权重初始化新模型。这是一个代码片段:

    if fine_tuning:
        ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=path_saved_model,
                                            vars_to_warm_start='.*')
    else:
        ws = None
    
    estimator = tf.estimator.Estimator(model_fn=model_function,
                                                warm_start_from=ws,
                                                ...
                                                )
    

这将初始化在您当前定义的图形和保存的模型之间共享名称的任何变量。