尝试在Tensorflow中初始化神经网络,收到TypeError

时间:2019-04-30 01:18:04

标签: python python-3.x tensorflow

我在Tensorflow中正确初始化我的神经网络时遇到麻烦。

在我的BayesianNN类中,我有一个build_graph函数:

def build_graph():
 self._create_feedforward() 
 self._initializer()
 self._define_layers()
 self._regularization()

截至目前,我的_create_feedforward()设置了我希望如何初始化权重和偏差以及其输出的框架:

def _create_feedforward(self, input, output, scope):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE) as scope:
        self.weights = tf.get_variable('weights',
                                       shape=(input.shape[1], output),
                                       initializer=tf.random_normal_initializer(0,0.01),
                                       dtype=tf.float32)
        self.biases = tf.get_variable('biases',
                                      shape=(output),
                                      initializer=tf.constant_initializer(0.0),
                                      dtype=tf.float32)

        activation = tf.matmul(input, self.weights) + self.biases

        return tf.matmul(tf.diag(self.Bern_prob.sample((input.shape[1],))), 
                                  activation) 

虽然我的_define_layers函数设置了网络的输入参数:

def _define_layers(self):
    layer_1_output = _create_feedforward(model_X, self.layer_1_dim, 'layer_1')
    layer_2_output = _create_feedforward(self.layer_1_dim, self.layer_2_dim, 'layer_2')
    layer_3_output = _create_feedforward(self.layer_2_dim, [1], 'layer_3')

因为在_define_layers之前调用_create_feedforward,所以我收到的typeError的输入量不足。但是我不能先调用_define_layers,因为尚未定义_create_feedforward。

  

TypeError:_create_feedforward()缺少3个必需的位置   参数:“输入”,“输出”和“作用域”

我知道为什么会发生此错误,但是如何在代码中整洁且没有错误地实现它呢?

1 个答案:

答案 0 :(得分:1)

这里的问题是,在_create_feedforward中,您拥有参数self。但是,在_define_layers中调用该函数时,不会将该函数作为类的一部分来调用。试试这个:

def _define_layers(self):
    layer_1_output = self._create_feedforward(model_X, self.layer_1_dim, 'layer_1')
    layer_2_output = self._create_feedforward(self.layer_1_dim, self.layer_2_dim,'layer_2')
    layer_3_output = self._create_feedforward(self.layer_2_dim, [1], 'layer_3')