theano的扫描功能如何工作?

时间:2017-03-13 14:45:13

标签: python theano theano.scan

看看这段代码:

import theano
import numpy
import theano.tensor as T 
import numpy as np

x = T.dvector('x')
y = T.dvector('y')

def fun(x,a):
    return x+a

results, updates = theano.scan(fn=fun,sequences=dict(input=x), outputs_info=dict(initial=y, taps=[-3]))

h = [10.,20,30,40,50,60,70]
f = theano.function([x, y], results)
g = theano.function([y], y)

print(f([1],h))

我已将outputs_info'taps更改为-2,-3等,但代码的结果是相同的[11.0],我无法理解。有人可以解释一下吗?

另一个问题。

import theano
import numpy
import theano.tensor as T 
import numpy as np

x = T.dvector('x')
y = T.dvector('y')

def fun(x,a,b):
    return x+a+b

results, updates = theano.scan(fn=fun,sequences=dict(input=x), outputs_info=dict(initial=y, taps=[-5,-3]))

h = [10.,20,30,40,50,60,70]
f = theano.function([x, y], results)
g = theano.function([y], y)

print(f([1,2,3,4],h))

输出是[41,62,83,85],85是怎么来的?

1 个答案:

答案 0 :(得分:1)

考虑代码的这种变化:

x = T.dvector('x')
y = T.dvector('y')

def fun(x,a,b):
    return x+b

results, updates = theano.scan(
    fn=fun,
    sequences=dict(input=x), 
    outputs_info=dict(initial=y, taps=[-5,-3])
)

h = [10.,20,30,40,50,60,70]
f = theano.function([x, y], results)
g = theano.function([y], y)

print(f([1],h))

您的结果将是31。

  • 将点击更改为[-5, -2],结果更改为41。
  • 将点击更改为[-4, -3],结果更改为21。

这表明事情是如何运作的:

  1. 水龙头中最大的负数被视为h [0]
  2. 所有其他水龙头都偏离
  3. 因此,当点击次数为[-5,-2]时,有趣的输入ab = 10和40。

    更新新问题

    taps实际上表示时间t处的函数取决于函数t - taps时的输出。

    例如,Fibonacci序列由函数

    定义

    f1

    以下是您如何使用theano.scan实施Fibonacci序列:

    x = T.ivector('x')
    y = T.ivector('y')
    
    def fibonacci(x,a,b):
        return a+b
    
    results, _ = theano.scan(
        fn=fibonacci,
        sequences=dict(input=x), 
        outputs_info=dict(initial=y, taps=[-2,-1])
        )
    
    h = [1,1]
    f = theano.function([x, y], results)
    
    print(np.append(h, f(range(10),h)))
    

    但是,theano.scan有问题。如果函数依赖于先前的输出,那么您将使用什么作为第一次迭代的先前输出?

    答案是您的案例中的初始输入h。但是在你的情况下,h比你需要的长,你只需要5个元素(因为在你的情况下最大的抽头是-5)。使用h所需的5个元素后,您的函数将切换到函数的实际输出。

    这里简要介绍了代码中发生的事情:

    1. output[0] = x[0] + h[0] + h[2] = 41
    2. output[1] = x[1] + h[1] + h[3] = 62
    3. output[2] = x[2] + h[2] + h[4] = 83
    4. output[3] = x[3] + h[3] + output[0] = 85
    5. 你会看到,在时间= 4,我们有一个时间4-3的函数输出,并且输出是41.而且由于我们有输出,我们需要使用它,因为函数被定义为使用先前的输出。所以我们忽略h的剩余部分。