如何在张量流中正确实现延迟加载?

时间:2017-11-03 14:01:12

标签: python tensorflow

以下代码(尝试在https://danijar.com/structuring-your-tensorflow-models/中复制代码结构时)

4.30E+38

给出错误。其中一部分如下:

'%f' % client_id

如何解决此问题?

1 个答案:

答案 0 :(得分:3)

问题是,sess.run(tf.global_variables_initializer())的调用在创建变量之前发生了,在第一次调用model.output后的行中。

要解决此问题,您必须在调用model.output之前以某种方式访问​​sess.run(tf.global_variables_initializer())。例如,以下代码有效:

import tensorflow as tf

class Model:

    def __init__(self, x):
        self.x = x
        self._output = None

    @property
    def output(self):
        # NOTE: You must use `if self._output is None` when `self._output` can
        # be a tensor, because `if self._output` on a tensor object will raise
        # an exception.
        if self._output is None:
            weight = tf.Variable(tf.constant(4.0))
            bias = tf.Variable(tf.constant(2.0))
            self._output = tf.multiply(self.x, weight) + bias
        return self._output

def main():
    x = tf.placeholder(tf.float32)
    model = Model(x)

    # The variables are created on this line.
    output_t = model.output

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        output = sess.run(output_t, {x: 4.0})
        print(output)

if __name__ == '__main__':
    main()