Is there a way to implement the stick-breaking process with tf.scan? Specifically, I have Theano code that I cannot convert to TensorFlow. I tried a standard loop with tensor slicing and concatenation, but I don't think gradients can flow through that construction during optimization. The Theano code is:
stick_segment = theano.shared(value=np.zeros((batch_size,), dtype=theano.config.floatX),
                              name='stick_segment')
remaining_stick = theano.shared(value=np.ones((batch_size,), dtype=theano.config.floatX),
                                name='remaining_stick')

def compute_latent_vars(i, stick_segment, remaining_stick, v_samples):
    # compute stick segment
    stick_segment = v_samples[:, i] * remaining_stick
    remaining_stick *= (1 - v_samples[:, i])
    return (stick_segment, remaining_stick)

(stick_segments, remaining_sticks), updates = theano.scan(
    fn=compute_latent_vars,
    outputs_info=[stick_segment, remaining_stick],
    sequences=T.arange(latent_size - 1),
    non_sequences=[v_samples],
    strict=True)
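For context, the scan implements the stick-breaking recurrence: at step i, a fraction v_samples[:, i] of the remaining stick is broken off, and the remainder shrinks by the factor (1 - v_samples[:, i]); the last latent dimension gets whatever is left. A plain NumPy sketch of the same loop (batch_size, latent_size, and v_samples are assumed to have the same meaning as in the code above):

import numpy as np

def stick_breaking_np(v_samples, latent_size):
    # v_samples: (batch_size, latent_size - 1) array of break fractions in (0, 1)
    batch_size = v_samples.shape[0]
    remaining_stick = np.ones(batch_size)
    stick_segments = []
    for i in range(latent_size - 1):
        stick_segments.append(v_samples[:, i] * remaining_stick)       # piece broken off at step i
        remaining_stick = remaining_stick * (1 - v_samples[:, i])      # stick left after step i
    # the final dimension is the leftover stick, so each row sums to 1
    return np.column_stack(stick_segments + [remaining_stick])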
I was able to translate the Theano code to TensorFlow, but I cannot get gradients through it. The code is:
def fn(previous_output, current_input):
    [stick, remaining] = previous_output
    i = current_input
    stick = v[:, i] * remaining
    remaining *= (1 - v[:, i])
    return [stick, remaining]

elems = tf.Variable(tf.range(latent - 1))
[pi, rem] = tf.scan(fn, elems, initializer=[tf.ones([bs]), tf.ones([bs])])
z = tf.concat([pi, tf.reshape(rem[-1, :], [1, bs])], axis=0)
z = tf.transpose(z)
delt = tf.gradients(z, v)
return z, delt
When I look at the gradients of pi or rem with respect to v, they are non-zero, but the gradients of z all come out as 0. Is this caused by tf.concat? Is there any way to prevent this problem?
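For reference, here is a minimal, self-contained version of the setup described above (graph-mode TF 1.x; the values bs = 4 and latent = 6, the placeholder v, and the random feed values are just illustrative choices for this sketch, not part of my actual model), in case it helps to reproduce the gradient behaviour:

import numpy as np
import tensorflow as tf

bs = 4        # batch size, arbitrary for this sketch
latent = 6    # number of latent dimensions, arbitrary for this sketch

v = tf.placeholder(tf.float32, shape=[bs, latent - 1])  # break fractions in (0, 1)

def fn(previous_output, current_input):
    [stick, remaining] = previous_output
    i = current_input
    stick = v[:, i] * remaining          # piece broken off at step i
    remaining = remaining * (1 - v[:, i])
    return [stick, remaining]

elems = tf.Variable(tf.range(latent - 1))
[pi, rem] = tf.scan(fn, elems, initializer=[tf.ones([bs]), tf.ones([bs])])
z = tf.concat([pi, tf.reshape(rem[-1, :], [1, bs])], axis=0)
z = tf.transpose(z)

grad_pi = tf.gradients(pi, v)
grad_z = tf.gradients(z, v)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {v: np.random.uniform(0.1, 0.9, size=(bs, latent - 1)).astype(np.float32)}
    print(sess.run(grad_pi, feed))   # gradients of the scan outputs
    print(sess.run(grad_z, feed))    # gradients after the concat/transpose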