theano.scan
返回两个变量:values变量和更新变量。例如,
a = theano.shared(1)
values, updates = theano.scan(fn=lambda a:a+1, outputs_info=a, n_steps=10)
但是,我注意到在我使用的大多数示例中,updates变量都是空的。似乎只有当我们以theano.scan
编写函数是某种方式时,我们才会获得更新。例如,
a = theano.shared(1)
values, updates = theano.scan(lambda: {a: a+1}, n_steps=10)
有人可以向我解释为什么在第一个例子中更新是空的,但在第二个例子中,更新变量不为空?更一般地说,theano.scan
中的更新变量如何工作?感谢。
答案 0 :(得分:12)
考虑以下四种变体(可以执行此代码以观察差异)和分析。
import theano
def v1a():
a = theano.shared(1)
outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
f = theano.function([], outputs=outputs)
print f(), a.get_value()
def v1b():
a = theano.shared(1)
outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
f = theano.function([], outputs=outputs, updates=updates)
print f(), a.get_value()
def v2a():
a = theano.shared(1)
outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
f = theano.function([], outputs=outputs)
print f(), a.get_value()
def v2b():
a = theano.shared(1)
outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
f = theano.function([], outputs=outputs, updates=updates)
print f(), a.get_value()
def main():
v1a()
v1b()
v2a()
v2b()
main()
此代码的输出是
[ 2 3 4 5 6 7 8 9 10 11] 1
[ 2 3 4 5 6 7 8 9 10 11] 1
[] 1
[] 11
v1x
版本使用lambda x: x + 1
。 lambda函数的结果是一个符号变量,其值比输入大1。 lambda函数参数的名称已更改,以避免隐藏共享变量名称。在这些变体中,共享变量不会被扫描以任何方式使用或操纵,除了将其用作通过扫描步骤函数递增的循环符号变量的初始值。
v2x
版本使用lambda {a: a + 1}
。 lambda函数的结果是一个字典,解释了如何更新共享变量a
。
updates
变体中的v1x
为空,因为我们尚未从定义任何共享变量更新的步骤函数返回字典。 outputs
变体中的v2x
为空,因为我们没有提供步进函数的任何符号输出。 updates
仅在步骤函数返回共享变量更新表达式字典时使用(如在v2x
中),并且outputs
仅在步骤函数返回符号变量输出时使用(如{ {1}})。
返回字典时,如果未提供给v1x
,则无效。请注意,共享变量尚未在theano.function
中更新,但已在v2a
中更新。
答案 1 :(得分:6)
为了补充Daniel的答案,如果您想同时计算theano扫描中的输出和更新,请查看此示例。
此代码循环遍历序列,计算其元素的总和并更新共享变量t
(句子的长度)
import theano
import numpy as np
t = theano.shared(0)
s = theano.tensor.vector('v')
def rec(s, first, t):
first = s + first
second = s
return (first, second), {t: t+1}
first = np.float32(0)
(firsts, seconds), updates = theano.scan(
fn=rec,
sequences=s,
outputs_info=[first, None],
non_sequences=t)
f = theano.function([s], [firsts, seconds], updates=updates, allow_input_downcast=True)
v = np.arange(10)
print f(v)
print t.get_value()
此代码的输出是
[array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45.], dtype=float32),
array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)]
10
rec
函数输出元组和字典。扫描序列将计算输出并将字典添加到更新中,允许您创建更新t
并同时计算firsts
和seconds
的函数。