如何在训练当前图形时使用恢复的图形

时间:2018-02-08 08:23:35

标签: tensorflow graph restore training-data

我目前正在研究复杂的模型(对于我的弱脑来说太复杂了),我同时开始使用张量流。

def load_model():
     checkpoint = tf.train.latest_checkpoint('my_checkpoint')
     new_graph = tf.Graph() #the one that we need to restore

     with tf.Session(graph=new_graph) as sess:
         saver = tf.train.import_meta_graph(checkpoint + '.meta')
         saver.restore(sess, checkpoint)

     print("Model restored")
     return graph_seg

所以,这是我加载以前保存和训练过的模型的功能。 显然,它似乎工作正常,并加载我需要的操作。

现在,我想创建我的主模型:

def create_main_model(X,Y):
    with tf.name_scope("G_on_real"):
        with tf.variable_scope("G"):
            Y_channels = int(Y.get_shape()[-1])
            fake_Y = create_generator(X, Y_channels)

    #Blablah, we define all the things that we need.

    loaded_graph = load_model()

    with loaded_graph.as_default():
        with tf.Session() as sess:
            results = sess.run(fake_Y) # Trick here !

问题是我希望能够使用 fake_Y 作为加载模型的输入。训练阶段的每一步都会生成 fake_Y 。通过获得的输出,我想计算一个新的损失并将其整合到我的主模型的总损失中。

我知道 fake_Y 并不存在于 loaded_graph 图表中。而tensorflow会给我这个错误:

ValueError: Fetch argument <tf.Tensor 'G_on_real/G/decoder/deconv_3/Tanh:0' shape=(1, 128, 128, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("G_on_real/G/decoder/deconv_3/Tanh:0", shape=(1, 128, 128, 3), dtype=float32) is not an element of this graph.) 

所以我认为我的问题非常基本。我只需要找到一种方法来连接这两个图。但是张量流的工作流程在我看来并不是很清楚。

在两个断开连接的图表的简单情况下,我可以在 fake_Y 上调用 sess.run 。但是在这里,由于我还在构建我的主模型,所以不可能,所以我不能用一堆未初始化的东西来调用会话。

那么,有没有办法在我的两个图表之间共享 fake_Y (一个初始化而另一个没有)?

任何帮助将不胜感激! 谢谢

编辑:我找到了直接在当前图表中加载模型的解决方案(然后我不再需要创建新图表了)。我不知道哪种解决方案最好。但无论如何,问题仍然存在:如何在 fake_Y 上运行加载/初始化模型(与主要模型共享当前图形):在我们的函数中调用sess.run构建模型非常奇怪且不可行,因为某些对象尚未初始化。

2 个答案:

答案 0 :(得分:0)

正确的解决方案是在当前图表中加载新模型。

答案 1 :(得分:0)

如上所述。我初始化并加载了我的模型。然后,代码如下:

with tf.variable_scope("segmentation_model"):
    seg_model = create_previous_model(X, Y)

checkpoint_seg = tf.train.latest_checkpoint('my_checkpoint')
restore_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="segmentation_model")
restore_variables_keys = [x.name.replace('segmentation_model/', '').replace(':0', '') for x in restore_variables]
dict_restore_variables = dict(zip(restore_variables_keys, restore_variables))

saver = tf.train.Saver(dict_restore_variables)
with tf.Session() as sess_seg:
    saver.restore(sess_seg, checkpoint_seg)

所以我创建了之前没有重量的模型。当我打电话给 saver.restore 时,它将包含未来加载的模型。请注意,已加载的模型已经使用不同的X和Y进行了训练,但这没关系,对吧?

问题是我的主模型和我想加载的模型具有相同的变量。所以我必须创建我想要在特定范围内加载的模型,然后使用字典在正确的范围内加载模型。

在我的主模型中,我有这一行:

seg_model = load_segmentation_model(X,Y) #which do the job described above
output_seg_model = seg_models.outputs

#Few lines to define losses and other tensors of the main model.

我已经通过打印“segmentation_model”范围中包含的变量检查了模型是否被有效加载。一切似乎都没问题。

Tensorflow不会返回任何错误。但是当我显示加载模型的输出(这是一个图像)时,图像似乎是由未经训练的生成器生成的(我正在生成模型)。事实上,当我不打电话给 saver.restore 时,输出保持不变。所以这个电话没用,但没有任何错误:'(

我是否想念一些明显的东西?

谢谢!