我想要做的是同时运行多个预先训练好的Tensorflow网络。因为每个网络中的一些变量的名称可以是相同的,所以常见的解决方案是在创建网络时使用名称范围。但问题是我训练了这些模型并将训练过的变量保存在几个检查点文件中。在创建网络时使用名称范围后,我无法从检查点文件加载变量。
例如,我训练了一个AlexNet,我想比较两组变量,一组来自纪元10(保存在文件epoch_10.ckpt中),另一组来自纪元50(保存在文件epoch_50.ckpt)。因为这两者是完全相同的网,所以内部变量的名称是相同的。我可以使用
创建两个网with tf.name_scope("net1"):
net1 = CreateAlexNet()
with tf.name_scope("net2"):
net2 = CreateAlexNet()
但是,我无法从.ckpt文件中加载训练过的变量,因为当我训练这个网时,我没有使用名称范围。即使我可以将名称范围设置为" net1"当我训练网时,这阻止我加载net2的变量。
我试过了:
with tf.name_scope("net1"):
mySaver.restore(sess, 'epoch_10.ckpt')
with tf.name_scope("net2"):
mySaver.restore(sess, 'epoch_50.ckpt')
这不起作用。
解决此问题的最佳方法是什么?
答案 0 :(得分:12)
最简单的解决方案是创建不同的会话,为每个模型使用单独的图形:
# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
net1 = CreateAlexNet()
saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')
# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
net2 = CreateAlexNet()
saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net1_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
如果由于某种原因这不起作用,并且您必须使用单个tf.Session
(例如,因为您希望在另一个TensorFlow计算中组合来自两个网络的结果),最佳解决方案是:
tf.train.Saver
实例,并使用另一个参数重新映射变量名称。 当constructing存储区时,您可以将字典作为var_list
参数传递,将检查点中的变量名称(即没有名称范围前缀)映射到tf.Variable
您在每个模型中创建的对象。
您可以通过编程方式构建var_list
,并且您应该可以执行以下操作:
with tf.name_scope("net1"):
net1 = CreateAlexNet()
with tf.name_scope("net2"):
net2 = CreateAlexNet()
# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
net1_varlist = {v.name.lstrip("net1/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)
# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.name.lstrip("net2/"): v
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)
# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
答案 1 :(得分:1)
我遇到了困扰我很久的问题。我在这里找到了一个很好的解决方案:Loading two models from Saver in the same Tensorflow session和TensorFlow checkpoint save and read。
tf.train.Saver()
的默认行为是将每个变量与相应op的名称相关联。这意味着每次构造tf.train.Saver()
时,它都包含先前调用的所有变量。因此,您应该创建不同的图表并与它们运行不同的会话。