从TensorFlow中的检查点恢复后修改变量名称

时间:2017-01-31 20:49:53

标签: python tensorflow neural-network deep-learning

让我们说我有两种不同配置的型号。我训练和检查他们。

lower triangular

我想在同一会话中加载这些模型,以便我可以同时使用它们。对于上述两个模型,变量名称大致相同,因此为了避免名称损坏,我为每个模型添加了名称范围。

def train_and_save_model_using_config1():
    ...


def train_and_save_model_using_config2():
    ...

要从with tf.variable_scope("config1"): m1 = load_model_from_ckpt_with_config1() with tf.variable_scope("config2"): m2 = load_model_from_ckpt_with_config2() 的检查点恢复,我会收集变量和变量名称,但希望使用适当的范围重命名。

config1

但是我收到以下错误:

path = get_path_of_config1()
var_names = tf.contrib.framework.list_variables(path)
vars = {}
for name, shape in var_names:
    var = tf.contrib.framework.load_variable(path, name)
    vars["config1/" + name] = var

saver = tf.train.Saver(var_list=vars)
saver.restore(sess, tf.train.latest_checkpoint(path))

1 个答案:

答案 0 :(得分:0)

vars中的saver = tf.train.Saver(var_list=vars)字典必须是一个其值是当前会话中当前图形的 tf.Variable 引用的字典。
但是对于您的情况var = tf.contrib.framework.load_variable(path, name),问题在于var numpy.ndarray