逐渐建立tf.Graph并执行它

时间:2017-06-24 08:55:38

标签: python tensorflow

我正在尝试根据某些条件逐渐构建tf.Graph并在完成追加后运行一次。

代码如下所示:

class Model:
    def __init__(self):
        self.graph = tf.Graph()
        ... some code ...

    def build_initial_graph(self):
        with self.graph.as_default():
            X = tf.placeholder(tf.float32, shape=some_shape)
            ... some code ...

    def add_to_existing_graph(self):
        with self.graph.as_default():
            ... some code adding more ops to the graph ...

    def transform(self, data):
        with tf.Session(graph=self.graph) as session:
             y = session.run(Y, feed_dict={X: data})
        return y

调用方法看起来像这样

model = Model()
model.build_initial_graph()
model.add_to_existing_graph()
model.add_to_existing_graph()
result = model.transform(data)

所以,两个问题

  1. 这种方式对现有图表添加操作是否合法?在不同的地方使用相同的图形对象还是会覆盖旧图形对象?
  2. 在转换方法中,代码运行时无法识别X中的feed_dict,实现该方法的正确方法是什么?

1 个答案:

答案 0 :(得分:2)

Q1:这当然是构建模型的合法方式,但更多的是意见问题。我只建议将您的张量存储为属性(请参阅Q2的回答。)self.X=...

您可以看一下这个very nice post如何以面向对象的方式构建TensorFlow模型。

Q2 :原因很简单,因为变量X不属于transform方法的范围。
如果您执行以下操作,一切都会正常工作:

def build_initial_graph(self):
    with self.graph.as_default():
        self.X = tf.placeholder(tf.float32, shape=some_shape)
        ... some code ...

def transform(self, data):
    with tf.Session(graph=self.graph) as session:
         return session.run(self.Y, feed_dict={self.X: data})

更详细一点,在TensorFlow中,您定义的所有张量或操作(例如tf.placeholdertf.matmul)都在tf.Graph() re working on. You might want to store them in Python variable, as you did by doing X X = tf中定义.placeholder`但这不是强制性的。

如果您想要在您定义的Tensor之一后访问,您可以

  • 使用Python变量(这是你的尝试,除了变量tf.get_variable不在方法的范围内)或,
  • 使用systemctl status name_service.service 方法直接从图表中检索它们(您需要知道它的名称)。