我目前正在研究复杂的模型(对于我的弱脑来说太复杂了),我同时开始使用张量流。
def load_model():
checkpoint = tf.train.latest_checkpoint('my_checkpoint')
new_graph = tf.Graph() #the one that we need to restore
with tf.Session(graph=new_graph) as sess:
saver = tf.train.import_meta_graph(checkpoint + '.meta')
saver.restore(sess, checkpoint)
print("Model restored")
return graph_seg
所以,这是我加载以前保存和训练过的模型的功能。 显然,它似乎工作正常,并加载我需要的操作。
现在,我想创建我的主模型:
def create_main_model(X,Y):
with tf.name_scope("G_on_real"):
with tf.variable_scope("G"):
Y_channels = int(Y.get_shape()[-1])
fake_Y = create_generator(X, Y_channels)
#Blablah, we define all the things that we need.
loaded_graph = load_model()
with loaded_graph.as_default():
with tf.Session() as sess:
results = sess.run(fake_Y) # Trick here !
问题是我希望能够使用 fake_Y 作为加载模型的输入。训练阶段的每一步都会生成 fake_Y 。通过获得的输出,我想计算一个新的损失并将其整合到我的主模型的总损失中。
我知道 fake_Y 并不存在于 loaded_graph 图表中。而tensorflow会给我这个错误:
ValueError: Fetch argument <tf.Tensor 'G_on_real/G/decoder/deconv_3/Tanh:0' shape=(1, 128, 128, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("G_on_real/G/decoder/deconv_3/Tanh:0", shape=(1, 128, 128, 3), dtype=float32) is not an element of this graph.)
所以我认为我的问题非常基本。我只需要找到一种方法来连接这两个图。但是张量流的工作流程在我看来并不是很清楚。
在两个断开连接的图表的简单情况下,我可以在 fake_Y 上调用 sess.run 。但是在这里,由于我还在构建我的主模型,所以不可能,所以我不能用一堆未初始化的东西来调用会话。
那么,有没有办法在我的两个图表之间共享 fake_Y (一个初始化而另一个没有)?
任何帮助将不胜感激! 谢谢
编辑:我找到了直接在当前图表中加载模型的解决方案(然后我不再需要创建新图表了)。我不知道哪种解决方案最好。但无论如何,问题仍然存在:如何在 fake_Y 上运行加载/初始化模型(与主要模型共享当前图形):在我们的函数中调用sess.run构建模型非常奇怪且不可行,因为某些对象尚未初始化。答案 0 :(得分:0)
正确的解决方案是在当前图表中加载新模型。
答案 1 :(得分:0)
如上所述。我初始化并加载了我的模型。然后,代码如下:
with tf.variable_scope("segmentation_model"):
seg_model = create_previous_model(X, Y)
checkpoint_seg = tf.train.latest_checkpoint('my_checkpoint')
restore_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="segmentation_model")
restore_variables_keys = [x.name.replace('segmentation_model/', '').replace(':0', '') for x in restore_variables]
dict_restore_variables = dict(zip(restore_variables_keys, restore_variables))
saver = tf.train.Saver(dict_restore_variables)
with tf.Session() as sess_seg:
saver.restore(sess_seg, checkpoint_seg)
所以我创建了之前没有重量的模型。当我打电话给 saver.restore 时,它将包含未来加载的模型。请注意,已加载的模型已经使用不同的X和Y进行了训练,但这没关系,对吧?
问题是我的主模型和我想加载的模型具有相同的变量。所以我必须创建我想要在特定范围内加载的模型,然后使用字典在正确的范围内加载模型。
在我的主模型中,我有这一行:
seg_model = load_segmentation_model(X,Y) #which do the job described above
output_seg_model = seg_models.outputs
#Few lines to define losses and other tensors of the main model.
我已经通过打印“segmentation_model”范围中包含的变量检查了模型是否被有效加载。一切似乎都没问题。
Tensorflow不会返回任何错误。但是当我显示加载模型的输出(这是一个图像)时,图像似乎是由未经训练的生成器生成的(我正在生成模型)。事实上,当我不打电话给 saver.restore 时,输出保持不变。所以这个电话没用,但没有任何错误:'(
我是否想念一些明显的东西?
谢谢!