我将使用tf.scan
函数修改其他人共享here的可变长度序列的DRAW(深度循环注意写入器)代码。所以我需要将原始代码中的for循环更改为适合扫描功能的结构。下面是代码的原始部分,
...
for t in range(T):
c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
x_hat=x-tf.sigmoid(c_prev) # error image
r=read(x,x_hat,h_dec_prev)
h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev]))
z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
h_dec,dec_state=decode(dec_state,z)
cs[t]=c_prev+write(h_dec) # store results
h_dec_prev=h_dec
DO_SHARE=True # from now on, share variables
...
为了使用tf.scan
,我需要传递几个先前的状态(c_prev
,h_dec_prev
...)。但是,正如我所知tf.scan
elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x: a + x, elems)
仅为循环获得一个张量(是不是?)
a
似乎应该只有一个page.php
,它应该是一个张量。在这种情况下,只有我可以想象的可能方式是展平几个不同的状态张量并将其连接起来。但是我担心它会弄乱代码并使速度慢下来,特别是当状态大小不同时。是否有任何有效(和快速)的方法来处理这类问题?