theano.function更新的功能

时间:2017-05-06 23:53:31

标签: theano

我正在尝试理解以下的theano代码。

self.sgd_step = theano.function(
            [x, y, learning_rate, theano.Param(decay, default=0.9)],
            [], 
            updates=[(E, E - learning_rate * dE / T.sqrt(mE + 1e-6)),
                     (U, U - learning_rate * dU / T.sqrt(mU + 1e-6)),
                     (W, W - learning_rate * dW / T.sqrt(mW + 1e-6)),
                     (V, V - learning_rate * dV / T.sqrt(mV + 1e-6)),
                     (b, b - learning_rate * db / T.sqrt(mb + 1e-6)),
                     (c, c - learning_rate * dc / T.sqrt(mc + 1e-6)),
                     (self.mE, mE),
                     (self.mU, mU),
                     (self.mW, mW),
                     (self.mV, mV),
                     (self.mb, mb),
                     (self.mc, mc)
])

有人可以告诉我,上面代码的作者试图在那里做什么?有一个值[x, y, learning_rate, theano.Param(decay, default=0.9)]正在尝试更新,价值将由[]更新?这里updates的功能是什么?

如果我能够了解上述代码中的内容,我将非常感激吗?

1 个答案:

答案 0 :(得分:2)

updates的文档如下(取自here)。

  必须为

更新提供表单对的列表(共享变量,新表达式)。它也可以是一个字典,其键是共享变量,值是新表达式。无论哪种方式,它都意味着“无论何时运行此函数,它都将用相应表达式的结果替换每个共享变量的.value”。在上面,我们的累加器用状态和增量的总和替换状态的值。

因此,当您使用所需输入调用上述theano函数时,它将更新共享变量的值,即E, U, W, V, b, c, ..., self.mc。要更新的新值由元组中的第二个数量给出。基本上,E = E - learning_rate * dE / T.sqrt(mE + 1e-6)等等。