theano.scan的更新如何运作?

时间:2015-10-06 18:17:53

标签: python machine-learning theano

theano.scan返回两个变量:values变量和更新变量。例如,

a = theano.shared(1)

values, updates = theano.scan(fn=lambda a:a+1, outputs_info=a,  n_steps=10)

但是,我注意到在我使用的大多数示例中,updates变量都是空的。似乎只有当我们以theano.scan编写函数是某种方式时,我们才会获得更新。例如,

a = theano.shared(1)

values, updates = theano.scan(lambda: {a: a+1}, n_steps=10)

有人可以向我解释为什么在第一个例子中更新是空的,但在第二个例子中,更新变量不为空?更一般地说,theano.scan中的更新变量如何工作?感谢。

2 个答案:

答案 0 :(得分:12)

考虑以下四种变体(可以执行此代码以观察差异)和分析。

import theano


def v1a():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
    f = theano.function([], outputs=outputs)
    print f(), a.get_value()


def v1b():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda x: x + 1, outputs_info=a, n_steps=10)
    f = theano.function([], outputs=outputs, updates=updates)
    print f(), a.get_value()


def v2a():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
    f = theano.function([], outputs=outputs)
    print f(), a.get_value()


def v2b():
    a = theano.shared(1)
    outputs, updates = theano.scan(lambda: {a: a + 1}, n_steps=10)
    f = theano.function([], outputs=outputs, updates=updates)
    print f(), a.get_value()


def main():
    v1a()
    v1b()
    v2a()
    v2b()


main()

此代码的输出是

[ 2  3  4  5  6  7  8  9 10 11] 1
[ 2  3  4  5  6  7  8  9 10 11] 1
[] 1
[] 11

v1x版本使用lambda x: x + 1。 lambda函数的结果是一个符号变量,其值比输入大1。 lambda函数参数的名称已更改,以避免隐藏共享变量名称。在这些变体中,共享变量不会被扫描以任何方式使用或操纵,除了将其用作通过扫描步骤函数递增的循环符号变量的初始值。

v2x版本使用lambda {a: a + 1}。 lambda函数的结果是一个字典,解释了如何更新共享变量a

updates变体中的v1x为空,因为我们尚未从定义任何共享变量更新的步骤函数返回字典。 outputs变体中的v2x为空,因为我们没有提供步进函数的任何符号输出。 updates仅在步骤函数返回共享变量更新表达式字典时使用(如在v2x中),并且outputs仅在步骤函数返回符号变量输出时使用(如{ {1}})。

返回字典时,如果未提供给v1x,则无效。请注意,共享变量尚未在theano.function中更新,但已在v2a中更新。

答案 1 :(得分:6)

为了补充Daniel的答案,如果您想同时计算theano扫描中的输出和更新,请查看此示例。

此代码循环遍历序列,计算其元素的总和并更新共享变量t(句子的长度)

import theano
import numpy as np

t = theano.shared(0)
s = theano.tensor.vector('v')

def rec(s, first, t):
    first = s + first
    second = s
    return (first, second), {t: t+1}

first = np.float32(0)

(firsts, seconds), updates = theano.scan(
    fn=rec,
    sequences=s,
    outputs_info=[first, None],
    non_sequences=t)

f = theano.function([s], [firsts, seconds], updates=updates, allow_input_downcast=True)

v = np.arange(10)

print f(v)
print t.get_value()

此代码的输出是

[array([  0.,   1.,   3.,   6.,  10.,  15.,  21.,  28.,  36.,  45.], dtype=float32), 
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.], dtype=float32)]
10

rec函数输出元组和字典。扫描序列将计算输出并将字典添加到更新中,允许您创建更新t并同时计算firstsseconds的函数。