tensorflow仅从检查点恢复一些变量

时间:2018-05-02 08:11:25

标签: python tensorflow deep-learning

检查检查点后(我们称之为模型1)我在其中获得了以下变量名称列表(为简单起见缩短了):

var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]

假设我有一个更大的模型2,并希望用模型1中的值初始化它的一部分。

Obtain variables from names as described here(因为我没有找到一种简单的方法):

def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names]

构建模型2和恢复的保护程序:

logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))

目前

saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))

我收到错误:

"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")

请帮我找出我犯的错误。我还要感谢一种更简单的方式,因为这太可怕了。谢谢。

1 个答案:

答案 0 :(得分:0)

一个简单的解决方法是猜测是否应该恢复变量。

def ignore_name(name):
    if name.endswith('/Adam') or name.endswith('/Adam_1'):
        return True
    return False

您应该可以直接通过

使用这个想法
def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names and not ignore_name(v.name)]

这甚至允许使用ADAM训练模型,然后切换到SDG,反之亦然。