我正在尝试理解以下的theano代码。
self.sgd_step = theano.function(
[x, y, learning_rate, theano.Param(decay, default=0.9)],
[],
updates=[(E, E - learning_rate * dE / T.sqrt(mE + 1e-6)),
(U, U - learning_rate * dU / T.sqrt(mU + 1e-6)),
(W, W - learning_rate * dW / T.sqrt(mW + 1e-6)),
(V, V - learning_rate * dV / T.sqrt(mV + 1e-6)),
(b, b - learning_rate * db / T.sqrt(mb + 1e-6)),
(c, c - learning_rate * dc / T.sqrt(mc + 1e-6)),
(self.mE, mE),
(self.mU, mU),
(self.mW, mW),
(self.mV, mV),
(self.mb, mb),
(self.mc, mc)
])
有人可以告诉我,上面代码的作者试图在那里做什么?有一个值[x, y, learning_rate, theano.Param(decay, default=0.9)]
正在尝试更新,价值将由[]
更新?这里updates
的功能是什么?
如果我能够了解上述代码中的内容,我将非常感激吗?
答案 0 :(得分:2)
updates
的文档如下(取自here)。
必须为更新提供表单对的列表(共享变量,新表达式)。它也可以是一个字典,其键是共享变量,值是新表达式。无论哪种方式,它都意味着“无论何时运行此函数,它都将用相应表达式的结果替换每个共享变量的.value”。在上面,我们的累加器用状态和增量的总和替换状态的值。
因此,当您使用所需输入调用上述theano函数时,它将更新共享变量的值,即E, U, W, V, b, c, ..., self.mc
。要更新的新值由元组中的第二个数量给出。基本上,E = E - learning_rate * dE / T.sqrt(mE + 1e-6)
等等。