TensorFlow中的延迟加载实现

时间:2019-05-11 02:55:43

标签: python tensorflow

我正在从Lecture note 2: TensorFlow Ops了解TensorFlow。一切都很好,直到我在Note的结尾遇到了“延迟加载的陷阱”。我试图重新编写演示“延迟加载”的脚本,如下所示:

import tensorflow as tf

x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(tf.add(x, y))
    print (sess.run(tf.get_default_graph().as_graph_def()))

和结果:

...
node {
  name: "Add_8"
  op: "Add"
  input: "x/read"
  input: "y/read"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Add_9"
  op: "Add"
  input: "x/read"
  input: "y/read"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }

注释说明: “ 有两种方法可以避免此错误。首先,请尽可能将操作的定义及其执行分开。但是当由于您希望将相关操作分组为类而无法执行操作时,可以使用Python属性,以确保函数在首次调用时仅加载一次。 ” 我想为上面的脚本应用Python属性,以避免陷入延迟加载的陷阱。 请帮助我。

1 个答案:

答案 0 :(得分:1)

第二种方法假定您要将相关的操作分组到类中。因此,您的代码等效于以下内容:

import tensorflow as tf

class Test():
    def __init__(self):
        self.x = tf.Variable(10, name='x')
        self.y = tf.Variable(20, name='y')

    @property
    def add(self):
        self._value = tf.add(self.x,self.y)
        return self._value

with tf.Session() as sess:
    test = Test()
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(test.add)
    print(tf.get_default_graph().as_graph_def())
    # writer = tf.summary.FileWriter("tensorboard_model",tf.get_default_graph())
    # writer.close()

# print
...
node {
  name: "Add_9"
  op: "Add"
  input: "x/read"
  input: "y/read"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
...

您可以使用tensorboard查看操作节点,如下所示: enter image description here

但是我们可以使用行为类似于@property的{​​{3}}来确保函数在首次调用时仅加载一次。以下代码来自custom decorators

import functools
def lazy_property(function):
    attribute = '_cache_' + function.__name__
    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return decorator

我们可以使用它:

@lazy_property
def add(self):
    self._value = tf.add(self.x,self.y)
    return self._value

再次运行将产生以下结果。 Structuring Your TensorFlow Models