我正在尝试github的张量流代码。但是,我在gibbs采样部分遇到了一个问题。
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
def gibbs_sample(k):
#Runs a k-step gibbs chain to sample from the probability distribution of the RBM defined by W, bh, bv
def gibbs_step(count, k, xk):
#Runs a single gibbs step. The visible values are initialized to xk
hk = sample(tf.sigmoid(tf.matmul(xk, W) + bh)) #Propagate the visible values to sample the hidden values
xk = sample(tf.sigmoid(tf.matmul(hk, tf.transpose(W)) + bv)) #Propagate the hidden values to sample the visible values
return count+1, k, xk
#Run gibbs steps for k iterations
ct = tf.constant(0) #counter
[_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter,
gibbs_step, [ct, tf.constant(k), x], 1, False)
#This is not strictly necessary in this implementation, but if you want to adapt this code to use one of TensorFlow's
#optimizers, you need this in order to stop tensorflow from propagating gradients back through the gibbs step
x_sample = tf.stop_gradient(x_sample)
return x_sample
x = tf.placeholder(tf.float32, [None, 2340], name="x") #The placeholder variable that holds our data
x_sample = gibbs_sample(1)
错误来自control_flow_ops.while_loop
TypeError Traceback (most recent call last)
<ipython-input-11-3bb5ef935182> in <module>()
----> 1 x_sample = gibbs_sample(1)
<ipython-input-2-426df97982ef> in gibbs_sample(k)
10 ct = tf.constant(0) #counter
11 [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter,
---> 12 gibbs_step, [ct, tf.constant(k), x], 1, False)
13 #This is not strictly necessary in this implementation, but if you want to adapt this code to use one of TensorFlow's
14 #optimizers, you need this in order to stop tensorflow from propagating gradients back through the gibbs step
~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
3051 raise TypeError("body must be callable.")
3052 if parallel_iterations < 1:
-> 3053 raise TypeError("parallel_iterations must be a positive integer.")
3054
3055 if maximum_iterations is not None:
TypeError: parallel_iterations must be a positive integer.
根据github的讨论,我知道这个问题与多个并行运行的迭代有关。 https://github.com/tensorflow/tensorflow/issues/1984
while_loop实现非严格语义。迭代可以从开始 一旦这个迭代的操作之一准备就绪(即,它的全部 输入可用。)执行。所以while_loop可以轻松拥有 多个迭代并行运行。例如,对于扫描,甚至 如果累积值在步骤中不可用,则该步骤可以 仍然启动并执行任何不依赖于累积的操作 值。允许多次迭代并行运行的一个问题是 资源管理。 parallel_iterations是为了给用户而引入的 一些内存消耗和执行顺序的控制。
尽管知道它背后的问题,但由于缺乏gibbs采样和control_flow_ops模块,我无法修复代码。任何熟悉gibbs采样和control_flow_ops的人都可以帮我修复gibbs功能吗?
答案 0 :(得分:1)
在您的代码行中
[_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter, gibbs_step, [ct, tf.constant(k), x], 1, False)
你在tf.while_loop调用中最后传递的参数“false”被解释为parallel_iterations参数。我想你打算做以下事情:
[_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter, gibbs_step, [ct, tf.constant(k), x])