tf.assign to a variable slice does not work inside tf.while_loop

Time: 2018-07-07 18:14:22

Tags: python tensorflow

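What is wrong with the following code? The tf.assign op works fine when it is applied to a slice of a tf.Variable outside of a loop. In this context, however, it gives the error below.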

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i, a):
    return i < n 

def body(i, a):
    tf.assign(a[i], a[i-1] + a[i-2])
    return i + 1, a

i, b = tf.while_loop(cond, body, [2, a]) 

This results in:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3210, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2942, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2879, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/hrbigelow/ai/lb-wavenet/while_var_test.py", line 11, in body
    tf.assign(a[i], a[i-1] + a[i-2])
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 220, in assign
    return ref.assign(value, name=name)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 697, in assign
    raise ValueError("Sliced assignment is only supported for variables")
ValueError: Sliced assignment is only supported for variables

3 answers:

Answer 0 (score: 4)

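Your variable is not the output of an op that runs inside the loop. It is an external entity that lives outside the loop, so you do not need to pass it in as a loop variable. (When you do, body receives a plain tensor instead of the tf.Variable, which is why the traceback reports that sliced assignment is only supported for variables.)

In addition, you need to force the update to execute, for example with tf.control_dependencies inside body: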

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i):
    return i < n 

def body(i):
    op = tf.assign(a[i], a[i-1] + a[i-2])
    with tf.control_dependencies([op]):
      return i + 1

i = tf.while_loop(cond, body, [2])

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
i.eval()
print(a.eval())
# [ 1  1  2  3  5  8 13 21 34 55 89]
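To see why the control dependency matters, here is a toy pure-Python model of deferred (graph-mode) execution. Node, run and build_step are hypothetical names made up for this sketch, and it is not how TensorFlow is implemented; it only assumes the behavior the answer relies on: fetching a result executes just the ops that result depends on (through data or control edges), so an assign that nothing depends on is silently skipped.

```python
# Toy model of deferred execution (hypothetical, NOT TensorFlow internals):
# fetching a node runs only that node and the nodes it depends on.

class Node:
    def __init__(self, fn, inputs=(), control_inputs=()):
        self.fn = fn                                 # computes this node's value
        self.inputs = list(inputs)                   # data dependencies
        self.control_inputs = list(control_inputs)   # "run before me" dependencies


def run(node, cache=None):
    """Evaluate `node`, executing each reachable dependency at most once."""
    if cache is None:
        cache = {}
    if id(node) in cache:
        return cache[id(node)]
    for dep in node.control_inputs:
        run(dep, cache)                              # run only for its side effect
    args = [run(x, cache) for x in node.inputs]
    cache[id(node)] = node.fn(*args)
    return cache[id(node)]


def build_step(state, i, use_control_dependency):
    """One loop step: a side-effecting 'assign' plus the counter increment."""
    def do_assign():
        state[i] = state[i - 1] + state[i - 2]       # the update we want to happen

    assign = Node(do_assign)
    deps = [assign] if use_control_dependency else []
    return Node(lambda: i + 1, control_inputs=deps)  # fetching this yields i + 1


without = [1, 1, 0, 0]
run(build_step(without, 2, use_control_dependency=False))
print(without)   # [1, 1, 0, 0]: the assign was never reachable from the fetch

with_dep = [1, 1, 0, 0]
run(build_step(with_dep, 2, use_control_dependency=True))
print(with_dep)  # [1, 1, 2, 0]
```

In the model, as in the answer's code, only the increment is fetched; the control edge is the only thing that makes the assign reachable.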

You may also want to play it safe and set parallel_iterations=1 to force the loop iterations to run sequentially.
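Ordering matters because every step reads a[i-1] and a[i-2], which earlier steps write, while tf.while_loop allows up to parallel_iterations iterations (10 by default) to run concurrently. The pure-Python sketch below does not model TensorFlow's scheduler. It only contrasts strict in-order execution with a hypothetical extreme in which every iteration reads the values from before any iteration wrote anything.

```python
# Pure-Python illustration (not TensorFlow) of why iteration order matters
# for the in-place recurrence a[i] = a[i-1] + a[i-2].

def sequential(v):
    """Each iteration sees the writes of all earlier iterations."""
    a = list(v)
    for i in range(2, len(a)):
        a[i] = a[i - 1] + a[i - 2]
    return a


def stale_reads(v):
    """Extreme concurrent case: every iteration reads the initial values."""
    snapshot = list(v)
    a = list(v)
    for i in range(2, len(a)):
        a[i] = snapshot[i - 1] + snapshot[i - 2]
    return a


v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(sequential(v))   # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
print(stale_reads(v))  # [1, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0]
```

Real concurrent runs could produce other mixtures of fresh and stale reads; forcing one iteration at a time rules all of them out.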

Answer 1 (score: 0)

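From a CUDA point of view, it makes sense to forbid assigning to individual indices, because doing so negates all the performance benefits of heterogeneous parallel computing.

I know this adds some computational overhead, but it works.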
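One general way to avoid assigning to a single index is to rebuild the whole vector at each step. The sketch below is a pure-Python model of that idea, not this answer's code; step and fib_fill are hypothetical names. A rough TensorFlow analogue (an assumption, not verified here) would slice the tensor, join the pieces with tf.concat, and carry the full tensor as a loop variable.

```python
# Pure-Python model of a functional update: instead of writing into one
# index in place, each step returns a new full-length vector.

def step(a, i):
    """Return a new list equal to `a` except that position i holds a[i-1] + a[i-2]."""
    return a[:i] + [a[i - 1] + a[i - 2]] + a[i + 1:]


def fib_fill(v):
    a = list(v)
    for i in range(2, len(a)):
        a = step(a, i)   # copies the whole vector: O(n) extra work per step
    return a


print(fib_fill([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
# [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
```

The full copy per step is the price of never mutating a single element.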


Answer 2 (score: -1)

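I ran this several times and the results were not consistent. But slicing a variable does work inside a while loop.

I tried splitting up the graph inside body, because sometimes the result was incorrect.

Sometimes (but not always) it returns the correct answer: (11, array([ 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]))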

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a1 = tf.Variable(v, name = 'a')

def cond(i, _):
    return i < n

s = tf.InteractiveSession()
s.run(tf.global_variables_initializer())

def body( i, _):
    x = a1[i-1]
    y = a1[i-2]
    z = tf.add(x,y)
    op = a1[i].assign( z )
    with tf.control_dependencies([op]):  # Edit: this fixed the inconsistency.
       increment = tf.add(i, 1)
    return increment, op

print(s.run(tf.while_loop(cond, body, [2, a1])))