为什么类对象内的tf.Variable无法初始化?

时间:2019-07-04 03:25:25

标签: python tensorflow

tf.global_variables_initializer()的工作方式使我感到困惑。

此代码以未初始化的错误结束:

import tensorflow as tf

class C(object):
    def __init__(self):
        self.a = tf.Variable(tf.ones(()),tf.float32)

op_init = tf.global_variables_initializer()                                                                                                                                       12
with tf.Session() as sess:

    c = C()

    sess.run(op_init)
    print(sess.run(c.a))
    # -> fail. FailedPreconditionError: Attempting to use uninitialized value Variable

但是另一方面,波纹管起作用了。

import tensorflow as tf
class C(object):

    def __init__(self):
        self.a = tf.Variable(tf.ones(()),tf.float32)
        self.op_init = tf.global_variables_initializer()

with tf.Session() as sess:

    c = C()

    sess.run(c.op_init)
    print(sess.run(c.a))
    # -> success. '1.0'

2 个答案:

答案 0 :(得分:1)

关于初始化程序的“作用域”,另一个答案是错误的。在任何地方创建的 All 变量都会添加到全局变量的 collection 中,这是Tensorflow的概念,与Python变量范围无关。不幸的是,我没有再安装TF 1.X,因此我无法真正进行检查,但是我怀疑第一种情况下的问题是您在创建变量之前先创建了初始化程序。尝试像这样重新排序:

import tensorflow as tf

class C(object):
    def __init__(self):
        self.a = tf.Variable(tf.ones(()),tf.float32)

c = C()
op_init = tf.global_variables_initializer()                                                                                                                                       12
with tf.Session() as sess:
    sess.run(op_init)
    print(sess.run(c.a))

我所做的只是将c的创建移到了初始化程序之前。请让我们知道是否可行!

答案 1 :(得分:0)

Tensor未初始化,因为它不在初始化函数的范围内。第二种方法起作用,因为它在范围内。 tf.global_variables_initializer()为用户带来了便利,但这不是初始化变量的唯一且唯一的方法。

第一个选项的正确方法是传递初始值作为输入:

print(sess.run(c.a, feed_dict={c.a:1}))

这通常是您想要使用tf的方式。变量的内容可能会有所不同,这意味着它可能在每次运行时由用户初始化。