tf.concat()将渐变0?

时间:2018-03-12 18:36:21

标签: python tensorflow

有没有办法用tf.scan实现断纸过程?具体来说,我有Theano的代码,我无法将此代码转换为Tensorflow。我尝试使用标准循环和张量切片和连接,但我认为它无法在优化中获得渐变。

    stick_segment = theano.shared(value=np.zeros((batch_size,), dtype=theano.config.floatX), name='stick_segment')
    remaining_stick = theano.shared(value=np.ones((batch_size,), dtype=theano.config.floatX), name='remaining_stick')

    def compute_latent_vars(i, stick_segment, remaining_stick, v_samples):
        # compute stick segment                                                                                                     
        stick_segment = v_samples[:,i] * remaining_stick
        remaining_stick *= (1-v_samples[:,i])
        return (stick_segment, remaining_stick)

    (stick_segments, remaining_sticks), updates = theano.scan(fn=compute_latent_vars,
                                                              outputs_info=[stick_segment, remaining_stick],sequences=T.arange(latent_size-1),
                                                              non_sequences=[v_samples], strict=True)

我能够将Theano代码转换为Tensorflow,但我无法获得渐变。即,代码是:

def fn(previous_output,current_input):
    [stick,remaining] = previous_output
    i = current_input
    stick = v[:,i]*remaining
    remaining *= (1-v[:,i])
    return [stick,remaining]

elems = tf.Variable(tf.range(latent-1))
[pi,rem] = tf.scan(fn,elems,initializer=[tf.ones([bs]),tf.ones([bs])])
z = tf.concat([pi,tf.reshape(rem[-1,:],[1,bs])],axis=0)
z = tf.transpose(z)
delt = tf.gradients(z,v)
return z,delt

但是当我看到pi或rem的渐变时,它们是非零的。但是当我看到z的渐变时,它们全都变为0.是不是因为tf.concat?有什么方法可以防止这个问题吗?

0 个答案:

没有答案