Gibbs采样错误:parallel_iterations必须是正整数

时间:2018-06-07 06:45:08

标签: python-3.x tensorflow

我正在尝试github的张量流代码。但是,我在gibbs采样部分遇到了一个问题。

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

def gibbs_sample(k):
    #Runs a k-step gibbs chain to sample from the probability distribution of the RBM defined by W, bh, bv
    def gibbs_step(count, k, xk):
        #Runs a single gibbs step. The visible values are initialized to xk
        hk = sample(tf.sigmoid(tf.matmul(xk, W) + bh)) #Propagate the visible values to sample the hidden values
        xk = sample(tf.sigmoid(tf.matmul(hk, tf.transpose(W)) + bv)) #Propagate the hidden values to sample the visible values
        return count+1, k, xk

    #Run gibbs steps for k iterations
    ct = tf.constant(0) #counter
    [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter,
                                         gibbs_step, [ct, tf.constant(k), x], 1, False)
    #This is not strictly necessary in this implementation, but if you want to adapt this code to use one of TensorFlow's
    #optimizers, you need this in order to stop tensorflow from propagating gradients back through the gibbs step
    x_sample = tf.stop_gradient(x_sample) 
    return x_sample

x  = tf.placeholder(tf.float32, [None, 2340], name="x") #The placeholder variable that holds our data
x_sample = gibbs_sample(1) 

错误来自control_flow_ops.while_loop

TypeError                                 Traceback (most recent call last)
<ipython-input-11-3bb5ef935182> in <module>()
----> 1 x_sample = gibbs_sample(1)

<ipython-input-2-426df97982ef> in gibbs_sample(k)
     10     ct = tf.constant(0) #counter
     11     [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter,
---> 12                                          gibbs_step, [ct, tf.constant(k), x], 1, False)
     13     #This is not strictly necessary in this implementation, but if you want to adapt this code to use one of TensorFlow's
     14     #optimizers, you need this in order to stop tensorflow from propagating gradients back through the gibbs step

~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
   3051       raise TypeError("body must be callable.")
   3052     if parallel_iterations < 1:
-> 3053       raise TypeError("parallel_iterations must be a positive integer.")
   3054 
   3055     if maximum_iterations is not None:

TypeError: parallel_iterations must be a positive integer.

根据github的讨论,我知道这个问题与多个并行运行的迭代有关。 https://github.com/tensorflow/tensorflow/issues/1984

  

while_loop实现非严格语义。迭代可以从开始   一旦这个迭代的操作之一准备就绪(即,它的全部   输入可用。)执行。所以while_loop可以轻松拥有   多个迭代并行运行。例如,对于扫描,甚至   如果累积值在步骤中不可用,则该步骤可以   仍然启动并执行任何不依赖于累积的操作   值。允许多次迭代并行运行的一个问题是   资源管理。 parallel_iterations是为了给用户而引入的   一些内存消耗和执行顺序的控制。

尽管知道它背后的问题,但由于缺乏gibbs采样和control_flow_ops模块,我无法修复代码。任何熟悉gibbs采样和control_flow_ops的人都可以帮我修复gibbs功能吗?

1 个答案:

答案 0 :(得分:1)

在您的代码行中

    [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter, gibbs_step, [ct, tf.constant(k), x], 1, False)

你在tf.while_loop调用中最后传递的参数“false”被解释为parallel_iterations参数。我想你打算做以下事情:

    [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter, gibbs_step, [ct, tf.constant(k), x])