我已经训练了一个模型并将其保存在检查点中,但只是意识到我忘记在恢复模型时命名一个我想要检查的变量。
我知道如何从tensorflow,(g = tf.get_default_graph()
然后g.get_tensor_by_name([name])
)检索命名变量。在这种情况下,我知道它的范围,但它是未命名的。我试过查看tf.GraphKeys.GLOBAL_VARIABLES
,但由于某种原因,它没有出现在那里。
以下是它在模型中的定义:
with tf.name_scope("contrastive_loss") as scope:
l2_dist = tf.cast(tf.sqrt(1e-4 + tf.reduce_sum(tf.subtract(pred_left, pred_right), 1)), tf.float32) # the variable I want
# I use it here when calculating another named tensor, if that helps.
con_loss = contrastive_loss(l2_dist)
loss = tf.reduce_sum(con_loss, name="loss")
有没有找到没有名字的变量的方法?
答案 0 :(得分:7)
首先,跟进我的第一条评论,有意义的是tf.get_collection
给定名称范围不起作用。从the documentation开始,如果提供范围,则仅返回具有指定名称的变量或操作。那就好了。
您可以尝试的一件事是列出Graph
中每个节点的名称:
print([node.name for node in tf.get_default_graph().as_graph_def().node])
或者,从检查点恢复时可能:
saver = tf.train.import_meta_graph(/path/to/meta/graph)
sess = tf.Session()
saver.resore(sess, /path/to/checkpoints)
graph = sess.graph
print([node.name for node in graph.as_graph_def().node])
另一种选择是使用tensorboard或Jupyter Notebook和show_graph
命令显示图形。现在可能有一个内置的show_graph
,但该链接指向一个定义了一个的git存储库。然后,您必须在图表中搜索您的操作,然后可能使用以下命令检索它:
my_op = tf.get_collection('full_operation_name')[0]
如果您希望将来进行设置以便按名称检索,则需要使用tf.add_to_collection
将其添加到集合中:
my_op = tf.some_operation(stuff, name='my_op')
tf.add_to_collection('my_op_name', my_op)
然后通过恢复图表然后使用:
来检索它my_restored_op = tf.get_collection('my_op_name')[0]
您也可以通过命名它然后在tf.get_collection
中指定其范围来获得,但我不确定。可以找到更多信息和有用的教程here。
答案 1 :(得分:1)
tf.get_collection不适用于未命名的变量。因此,列出以下操作:
graph = sess.graph
print(graph.get_operations())
...在列表中找到你的张量然后:
global_step_tensor = graph.get_tensor_by_name('complete_operation_name:0')
我发现这个tutorial非常有助于理解这些背后的机制。