如何使用theano.scan()处理空序列?

时间:2015-12-22 10:25:23

标签: python theano

代码段如下:

t = tensor.arange(1, K)
results, updates = theano.scan(fn=updatefunc, sequences=t, ...)

扫描过程将沿着t迭代。但是,当K <= 1时,t将是一个空范围,那么theano.scan()将崩溃。有什么方法可以解决这个问题吗?

1 个答案:

答案 0 :(得分:1)

只有当序列中包含一些元素时,才可以使用theano.ifelse.ifelse来计算扫描。例如:

import theano
import theano.tensor as tt
import theano.ifelse


def step(x_t, s_tm1):
    return s_tm1 + x_t


def compile():
    K = tt.lscalar()
    t = tt.arange(1, K)
    zero = tt.constant(0, dtype='int64')
    outputs, _ = theano.scan(step, sequences=[t], outputs_info=[zero])
    output = theano.ifelse.ifelse(tt.gt(K, 1), outputs[-1], zero)
    return theano.function([K], outputs=[output])


def main():
    f = compile()
    print f(3)
    print f(2)
    print f(1)
    print f(0)
    print f(-1)


main()

打印

[array(3L, dtype=int64)]
[array(1L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]