如何在Tensorflow的全局命名空间中存储和检索Tensors?

时间:2018-07-07 01:28:53

标签: tensorflow

我正在尝试阅读一个大型Tensorflow项目。对于一个计算图的节点分散在该项目周围的项目,我想知道是否存在一种方法来存储计算图的Tensor节点并将该节点添加到sess.run中的获取列表中?

例如,如果我想在以下项目的第615行添加概率 https://github.com/allenai/document-qa/blob/master/docqa/nn/span_prediction.py到全局名称空间,有没有类似tf.add_node(probs,“ probs”)的方法,后来我可以得到tf.get_node(“ probs”),只是为了方便在节点周围传递节点项目。

一个更普遍的问题是,构造张量流代码并提高使用不同模型进行实验的效率的更好的主意是什么?

1 个答案:

答案 0 :(得分:1)

当然可以。要在以后检索它,您必须给它起一个名字,以便您可以按名称检索它。以代码中的probs为例。它是使用tf.nn.softmax()函数创建的,其API如下所示。

tf.nn.softmax(
    logits,
    axis=None,
    name=None,
    dim=None
)

看到参数name吗?您可以像下面这样将此参数添加到615行:

 probs = tf.nn.softmax(all_logits, name='my_tensor')

以后需要时,可以调用tf.Graph.get_tensor_by_name(name)来检索此张量。

graph = tf.get_default_graph()
retrieved_probs = graph.get_tensor_by_name('my_tensor:0')

'my_tensor' softmax 操作的名称,应在其末尾添加“:0”,这意味着您要获取张量而不是该操作。调用Graph.get_operation_by_name()时,不应添加“:0”。

您必须确保张量存在(它可以在此行之前执行的代码中创建,或者可以从元图文件中恢复)。如果是在可变范围内创建的,则还必须在name参数的前面添加范围名称和一个'/'。例如,'my_scope/my_tensor:0'