符号变量在theano中自动更新

时间:2015-01-18 11:11:32

标签: python logistic-regression theano

我遵循给出here的theano教程,用于简单的随机梯度下降。但是,我在这个块中无法理解p_y_given_xy_pred的值是如何根据Wb的值自动更新的运行test_logistic()我们只更新Wb的值? 感谢

class LogisticRegression(object):

    def __init__(self, input, n_in, n_out):
        self.W = theano.shared(
            value=numpy.zeros(
                (n_in, n_out),
                dtype=theano.config.floatX
            ),
            name='W',
            borrow=True
        )
        # initialize the baises b as a vector of n_out 0s
        self.b = theano.shared(
            value=numpy.zeros(
                (n_out,),
                dtype=theano.config.floatX
            ),
            name='b',
            borrow=True
        )
        self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)
        self.y_pred = T.argmax(self.p_y_given_x, axis=1)
        self.params = [self.W, self.b]

    def negative_log_likelihood(self, y):
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y])
        # end-snippet-2

    def errors(self, y):
        if y.ndim != self.y_pred.ndim:
            raise TypeError(
                'y should have the same shape as self.y_pred',
                ('y', y.type, 'y_pred', self.y_pred.type)
            )
        # check if y is of the correct datatype
        if y.dtype.startswith('int'):
            # the T.neq operator returns a vector of 0s and 1s, where 1
            # represents a mistake in prediction
            return T.mean(T.neq(self.y_pred, y))
        else:
            raise NotImplementedError() 

1 个答案:

答案 0 :(得分:2)

p_y_given_xy_pred是符号变量(只是来自Theano的python对象)。那些指向Theano对象的python变量没有得到更新。它们只代表我们想要做的计算。想象一下伪代码。

在编译Theano函数时将使用它们。只有这样才能计算出价值。但这不会导致指向对象p_y_given_xy_pred的python变量发生任何变化。对象没有改变。

了解这种区别对某些人来说需要时间。这是一种新的思维方式。所以不要犹豫提问。有一点有用的是总是问自己,你是处于象征世界还是数字世界。数字世界只发生在编译的Theano函数中。