让我们说我有两种不同配置的型号。我训练和检查他们。
lower triangular
我想在同一会话中加载这些模型,以便我可以同时使用它们。对于上述两个模型,变量名称大致相同,因此为了避免名称损坏,我为每个模型添加了名称范围。
def train_and_save_model_using_config1():
...
def train_and_save_model_using_config2():
...
要从with tf.variable_scope("config1"):
m1 = load_model_from_ckpt_with_config1()
with tf.variable_scope("config2"):
m2 = load_model_from_ckpt_with_config2()
的检查点恢复,我会收集变量和变量名称,但希望使用适当的范围重命名。
config1
但是我收到以下错误:
path = get_path_of_config1()
var_names = tf.contrib.framework.list_variables(path)
vars = {}
for name, shape in var_names:
var = tf.contrib.framework.load_variable(path, name)
vars["config1/" + name] = var
saver = tf.train.Saver(var_list=vars)
saver.restore(sess, tf.train.latest_checkpoint(path))
答案 0 :(得分:0)
vars
中的saver = tf.train.Saver(var_list=vars)
字典必须是一个其值是当前会话中当前图形的 tf.Variable 引用的字典。
但是对于您的情况var = tf.contrib.framework.load_variable(path, name)
,问题在于var
是 numpy.ndarray