在同一Tensorflow会话中从Saver加载两个模型

时间:2017-01-12 07:12:28

标签: python tensorflow

我有两个网络:一个Model生成输出,一个Adversary对输出进行分级。

两者都经过单独培训,但现在我需要在一次会议中将他们的输出结合起来。

我试图实施此帖中提出的解决方案:Run multiple pre-trained Tensorflow nets at the same time

我的代码

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题

对函数model_saver.restore()的调用似乎什么都不做。在另一个模块中,我使用tf.train.Saver(tf.global_variables())的保护程序,它可以恢复检查点。

模型有model.tvars = tf.trainable_variables()。为了检查发生了什么,我使用sess.run()在恢复之前和之后提取tvars。每次使用初始随机分配的变量并且未分配检查点的变量时。

有关为什么model_saver.restore()似乎什么都不做的任何想法?

3 个答案:

答案 0 :(得分:19)

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它。

为了诊断问题,我手动循环遍历每个变量并逐个分配它们。然后我注意到在分配变量后,名称会改变。这在此处描述:TensorFlow checkpoint save and read

根据该帖子中的建议,我在自己的图表中运行了每个模型。这也意味着我必须在自己的会话中运行每个图形。这意味着以不同方式处理会话管理。

首先我创建了两个图表

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

然后两个会议

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我在每个会话中初始化变量并单独恢复每个图表

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

每当需要每个会话时,我都会在此会话中使用tf包装任何with sess.as_default():个函数。最后,我手动关闭会话

sess.close()
adv_sess.close()

答案 1 :(得分:0)

请检查以下内容:

adv_varlist = {v.name.lstrip("avd/")[:-2]: v 

应该是“ adv”,不是

答案 2 :(得分:0)

标记为正确的答案并没有告诉我们如何将两个不同的模型显式加载到一个会话中,这是我的答案:

  1. 为要加载的模型创建两个不同的名称范围。

  2. 初始化两个保护程序,这两个保护程序将为两个不同网络中的变量加载参数。

  3. 从相应的检查点文件
  4. 加载。

with tf.Session() as sess:
    with tf.name_scope("net1"):
      net1 = Net1()
    with tf.name_scope("net2"):
      net2 = Net2()

    net1_varlist = {v.op.name.lstrip("net1/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)

    net2_varlist = {v.op.name.lstrip("net2/"): v
                    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)

    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")