我正在从Lecture note 2: TensorFlow Ops了解TensorFlow。一切都很好,直到我在Note的结尾遇到了“延迟加载的陷阱”。我试图重新编写演示“延迟加载”的脚本,如下所示:
import tensorflow as tf
x = tf.Variable(10, name='x')
y = tf.Variable(20, name='y')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(10):
sess.run(tf.add(x, y))
print (sess.run(tf.get_default_graph().as_graph_def()))
和结果:
...
node {
name: "Add_8"
op: "Add"
input: "x/read"
input: "y/read"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "Add_9"
op: "Add"
input: "x/read"
input: "y/read"
attr {
key: "T"
value {
type: DT_INT32
}
}
注释说明: “ 有两种方法可以避免此错误。首先,请尽可能将操作的定义及其执行分开。但是当由于您希望将相关操作分组为类而无法执行操作时,可以使用Python属性,以确保函数在首次调用时仅加载一次。 ” 我想为上面的脚本应用Python属性,以避免陷入延迟加载的陷阱。 请帮助我。
答案 0 :(得分:1)
第二种方法假定您要将相关的操作分组到类中。因此,您的代码等效于以下内容:
import tensorflow as tf
class Test():
def __init__(self):
self.x = tf.Variable(10, name='x')
self.y = tf.Variable(20, name='y')
@property
def add(self):
self._value = tf.add(self.x,self.y)
return self._value
with tf.Session() as sess:
test = Test()
sess.run(tf.global_variables_initializer())
for _ in range(10):
sess.run(test.add)
print(tf.get_default_graph().as_graph_def())
# writer = tf.summary.FileWriter("tensorboard_model",tf.get_default_graph())
# writer.close()
# print
...
node {
name: "Add_9"
op: "Add"
input: "x/read"
input: "y/read"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
...
但是我们可以使用行为类似于@property
的{{3}}来确保函数在首次调用时仅加载一次。以下代码来自custom decorators。
import functools
def lazy_property(function):
attribute = '_cache_' + function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
我们可以使用它:
@lazy_property
def add(self):
self._value = tf.add(self.x,self.y)
return self._value
再次运行将产生以下结果。 Structuring Your TensorFlow Models