Tensorflow 2.0中的tf.function和tf.while循环

时间:2019-12-14 04:50:34

标签: python tensorflow tensorflow2.0

我正在尝试使用tf.while_loop并行化循环。正如建议的here一样,parallel_iterations参数在渴望模式下没有什么不同。因此,我尝试用tf.while_loop包装tf.function。但是,添加装饰器后,迭代变量的行为将发生变化。

例如,这段代码有效。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)
tf.while_loop(c, print_fun, [iteration])

如果添加装饰器,则会发生错误。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    tf.while_loop(c, print_fun, [iteration])

run_graph()

在调试过程中,我发现变量iteration从张量更改为占位符。这是为什么?我应该如何修改代码以消除错误?

谢谢。

1 个答案:

答案 0 :(得分:0)

第一个代码段(不带@tf.function的代码)利用TensorFlow 2的急切执行直接操作一个numpy数组(即,外部iteration对象)。使用@tf.function时,此方法不起作用,因为@ tf.function尝试将您的代码编译为tf.Graph,后者无法直接对numpy数组进行操作(它只能处理tensorflow张量)。要解决此问题,请使用tf.Variable并继续在其分片中分配值。

使用@tf.function,实际上,您尝试使用@tf.function的Python到图形自动转换功能(称为AutoGraph),可以用更简单的代码实现目标。您只需编写一个普通的Python while循环(使用tf.less()代替<运算符),而While循环将由AutoGraph编译到幕后的tf.while_loop中。

代码类似于:

result = tf.Variable(np.zeros([10], dtype=np.int32))

@tf.function
def run_graph():
  i = tf.constant(0, dtype=tf.int32)
  while tf.less(i, 10):
    result[i].assign(i)  # Performance may require tuning here.
    i += 1

run_graph()
print(result.read_value())