检查检查点后(我们称之为模型1)我在其中获得了以下变量名称列表(为简单起见缩短了):
var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]
假设我有一个更大的模型2,并希望用模型1中的值初始化它的一部分。
Obtain variables from names as described here(因为我没有找到一种简单的方法):
def get_vars_by_name(names):
return [v for v in tf.global_variables() if v.name in names]
构建模型2和恢复的保护程序:
logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))
目前
saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
我收到错误:
"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")
请帮我找出我犯的错误。我还要感谢一种更简单的方式,因为这太可怕了。谢谢。
答案 0 :(得分:0)
一个简单的解决方法是猜测是否应该恢复变量。
def ignore_name(name):
if name.endswith('/Adam') or name.endswith('/Adam_1'):
return True
return False
您应该可以直接通过
使用这个想法def get_vars_by_name(names):
return [v for v in tf.global_variables() if v.name in names and not ignore_name(v.name)]
这甚至允许使用ADAM训练模型,然后切换到SDG,反之亦然。