我想使用theano.scan函数将代码更改为theano中的循环形式,但是结果却不是我所想的。我定义了一个名为tau_mu的全局变量,我认为这可能会导致问题,但是我无法弄清楚。
I want to change the below code to the form of loop:
(1)
import theano
import theano.tensor as TT
t=np.arange(10)
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu1=TT.switch(tau[0]>=t,mu[0],mu[1])
tau_mu2=TT.switch(tau[1]>=t,tau_mu1,mu[2])
tau_mu3=TT.switch(tau[2]>=t,tau_mu2,mu[3])
f = theano.function([tau,mu], tau_mu3)
f([3,5,7],[2,4,5,6])
结果: array([2。,2.,2.,2.,4.,4.,5.,5.,6.,6.,6.,6.,6.,6.,6.,6。])< / p>
(2)the following is the form of loop
tau = TT.vector('tau')
mu=TT.vector('mu')
tau_mu=TT.vector('tau_mu')
tau_mu=TT.switch(tau[0]>=t,mu[0],mu[1])
indc=TT.ivector('indc')
def one_step(indc,tau,mu):
global tau_mu
tau_mu=TT.switch(tau[indc]>=t,tau_mu,mu[indc+1])
return tau_mu
result,updates=theano.scan(fn=one_step,sequences=[indc],non_sequences=[tau,mu])
f = theano.function([indc,tau,mu], result)
f([1,2],[3,5,7],[2,4,5,6])
结果:
array([[2., 2., 2., 2., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
[2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]])
最终结果是
array([2., 2., 2., 2., 4., 4., 4., 4., 6., 6., 6., 6., 6., 6., 6.]]),
我的预期结果是
array([2., 2., 2., 2., 4., 4., 5., 5., 6., 6., 6., 6., 6., 6., 6.])
我应该在哪里更改代码? 预先感谢您提供的任何指导。