Tensorflow:在类中创建图表并在

时间:2016-06-12 05:04:53

标签: python class session graph tensorflow

我相信我很难理解图表在张量流中如何工作以及如何访问它们。我的直觉是用图表下的线条:'将图形形成为单个实体。因此,我决定创建一个在实例化时构建图形的类,并且拥有一个运行图形的函数,如下所示;

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             prediction = ... 
             cost       = ...
             optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(optimizer, feed_dict)
             loss = sess.run(cost, feed_dict)
             ...
        return variables

接下来的步骤是创建一个主文件,它将汇集参数传递给类,构建图形然后运行它;

#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }

#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...

这对我来说非常优雅,但它并不是很有效(显然)。实际上,似乎launchG函数无法访问图中定义的节点,这给出了我的错误;

---> 26 sess.run(optimizer, feed_dict)

NameError: name 'optimizer' is not defined

也许这是我的python(和tensorflow)理解太有限了,但我的奇怪印象是,在创建图形(G)的情况下,使用此图形作为参数运行会话应该可以访问节点中的节点它,不要求我提供明确的访问权限。

任何启示?

1 个答案:

答案 0 :(得分:14)

节点predictioncostoptimizer是在方法__init__中创建的局部变量,无法在方法launchG中访问它们。< / p>

最简单的解决方法是将它们声明为类Graph的属性:

class Graph(object):

    #To build the graph when instantiated
    def __init__(self, parameters ):
        self.graph = tf.Graph()
        with self.graph.as_default():
             ...
             self.prediction = ... 
             self.cost       = ...
             self.optimizer  = ...
             ...
    # To launch the graph
    def launchG(self, inputs):
        with tf.Session(graph=self.graph) as sess:
             ...
             sess.run(self.optimizer, feed_dict)
             loss = sess.run(self.cost, feed_dict)
             ...
        return variables

您还可以使用graph.get_tensor_by_namegraph.get_operation_by_name的确切名称检索图表的节点。