尝试将NN定义为类时初始化tf.variable时出错

时间:2018-01-17 17:27:19

标签: python tensorflow

我尝试使用python类定义一个简单的张量流图,如下所示:

import numpy as np
import tensorflow as tf

class NNclass:

def __init__(self, state_d, action_d, state):
    self.s_dim = state_d
    self.a_dim = action_d
    self.state = state
    self.prediction

@property
def prediction(self):
    a = tf.constant(5, dtype=tf.float32)
    w1 = tf.Variable(np.random.normal(0, 1))
    return tf.add(a, w1)

state = tf.placeholder(tf.float64, shape=[None, 1])
NN_instance = NNclass(1, 2, state)

ses = tf.Session()
ses.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(ses.run(NN_instance.prediction,  feed_dict={state: nn_input}))

当我运行此代码时,我收到以下错误:

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_1

我看到它的方式,我有一个 NNclass 的实例,我查看了tf图,因为 def__init __ 超过了预测方法。 但我不明白为什么运行它会产生上述错误。 请帮忙 感谢

2 个答案:

答案 0 :(得分:4)

创建所有变量后应调用

tf.global_variables_initializer()。在您的示例中,prediction函数定义w1变量,该变量在ses.run()之前未初始化。

您可以在__init__函数中创建变量,如下所示:

class NNclass:
    def __init__(self, state_d, action_d, state):
        self.s_dim = state_d
        self.a_dim = action_d
        self.state = state
        self.a = tf.constant(5, dtype=tf.float32)
        self.w1 = tf.Variable(np.random.normal(0, 1))

    @property
    def prediction(self):
        return tf.add(self.a, self.w1)

答案 1 :(得分:3)

将函数的结果传递给sess.run()并不是最佳做法,这会引起混淆。

配置网络的更好做法是创建build_graph()函数,其中定义了所有tensorflow操作。然后返回您需要计算的张量(更好的是,将它们存储在字典中或将它们保存为对象的属性)。

示例:

def build_graph():
  a = tf.constant(5, dtype=tf.float32)
  w1 = tf.Variable(np.random.normal(0, 1))
  a_plus_w = tf.add(a, w1)
  state = tf.placeholder(tf.float64, shape=[None, 1])
  return a_plus_w, state

a_plus_w, state = build_graph()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

nn_input = np.array([[0.5], [0.7]])
print(sess.run(a_plus_w,  feed_dict={state: nn_input}))

您所犯的关键错误是您没有在tensorflow中分离开发的两个阶段。你有一个“构建图”阶段,你定义你想要执行的所有数学运算,然后你有一个“执行”阶段在哪里使用sess.run来询问tensorflow为您执行计算。当你调用sess.run时,你需要传递tensorflow你想要计算的张量(已经在图中定义的tf对象)。你不应该通过tensorflow传递一个函数来执行。