代码段如下:
t = tensor.arange(1, K)
results, updates = theano.scan(fn=updatefunc, sequences=t, ...)
扫描过程将沿着t迭代。但是,当K <= 1时,t将是一个空范围,那么theano.scan()将崩溃。有什么方法可以解决这个问题吗?
答案 0 :(得分:1)
只有当序列中包含一些元素时,才可以使用theano.ifelse.ifelse
来计算扫描。例如:
import theano
import theano.tensor as tt
import theano.ifelse
def step(x_t, s_tm1):
return s_tm1 + x_t
def compile():
K = tt.lscalar()
t = tt.arange(1, K)
zero = tt.constant(0, dtype='int64')
outputs, _ = theano.scan(step, sequences=[t], outputs_info=[zero])
output = theano.ifelse.ifelse(tt.gt(K, 1), outputs[-1], zero)
return theano.function([K], outputs=[output])
def main():
f = compile()
print f(3)
print f(2)
print f(1)
print f(0)
print f(-1)
main()
打印
[array(3L, dtype=int64)]
[array(1L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]