我正在尝试使用tf.while_loop
并行化循环。正如建议的here一样,parallel_iterations
参数在渴望模式下没有什么不同。因此,我尝试用tf.while_loop
包装tf.function
。但是,添加装饰器后,迭代变量的行为将发生变化。
例如,这段代码有效。
result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
result[iteration] = iteration
iteration += 1
return (iteration,)
tf.while_loop(c, print_fun, [iteration])
如果添加装饰器,则会发生错误。
result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
result[iteration] = iteration
iteration += 1
return (iteration,)
@tf.function
def run_graph():
iteration = tf.constant(0)
tf.while_loop(c, print_fun, [iteration])
run_graph()
在调试过程中,我发现变量iteration
从张量更改为占位符。这是为什么?我应该如何修改代码以消除错误?
谢谢。
答案 0 :(得分:0)
第一个代码段(不带@tf.function
的代码)利用TensorFlow 2的急切执行直接操作一个numpy数组(即,外部iteration
对象)。使用@tf.function
时,此方法不起作用,因为@ tf.function尝试将您的代码编译为tf.Graph,后者无法直接对numpy数组进行操作(它只能处理tensorflow张量)。要解决此问题,请使用tf.Variable并继续在其分片中分配值。
使用@tf.function
,实际上,您尝试使用@tf.function
的Python到图形自动转换功能(称为AutoGraph),可以用更简单的代码实现目标。您只需编写一个普通的Python while循环(使用tf.less()
代替<
运算符),而While循环将由AutoGraph编译到幕后的tf.while_loop中。
代码类似于:
result = tf.Variable(np.zeros([10], dtype=np.int32))
@tf.function
def run_graph():
i = tf.constant(0, dtype=tf.int32)
while tf.less(i, 10):
result[i].assign(i) # Performance may require tuning here.
i += 1
run_graph()
print(result.read_value())