tf.cond:将元素添加到列表

时间:2018-10-05 09:57:47

标签: python tensorflow

我需要一个非常简单的图形中的条件控制流,该图形能够将元素添加到列表中。就我而言,该列表应使用tf.Variable([])声明。

我收到一个奇怪的错误:UnboundLocalError: local variable 'list1' referenced before assignment

这是一个玩具示例:

pred = tf.placeholder(tf.bool, shape=[])
list1 = tf.Variable([])

def f1():
  e1 = tf.constant(1.0)
  list1 = tf.concat([list1, [e1]], 0)

def f2():
  e2 = tf.constant(2.0)
  list1 = tf.concat([list1, [e2]], 0)

y = tf.cond(pred, f1, f2)
with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

1 个答案:

答案 0 :(得分:1)

tf.get_variable获取具有这些参数的现有变量或创建一个新变量。

所以这种模式可以避免您遇到的问题。

此外,我在关闭形状验证后还必须使用此行,因为这会导致验证错误。

       list1 = tf.assign( list1, tf.concat([list1, [e1]], 0), validate_shape=False)

工作代码是这个。

   def f1():
        with tf.variable_scope("reuse", reuse=tf.AUTO_REUSE):
            list1 = tf.get_variable(initializer=[], dtype=tf.float32, name='list1')
            e1 = tf.constant(1.)
            print(tf.shape(list1))
            list1 = tf.assign( list1, tf.concat([list1, [e1]], 0), validate_shape=False)
            return list1


    def f2():
        with tf.variable_scope("reuse", reuse=tf.AUTO_REUSE):
            list1 = tf.get_variable(  dtype=tf.float32, name='list1')
            e2 = tf.constant(2.)
            list1 = tf.assign( list1, tf.concat([list1, [e2]], 0), validate_shape=False)
            return list1


    y = tf.cond(pred, f1,  f2)

    with tf.Session() as session:

      with tf.variable_scope("reuse", reuse=tf.AUTO_REUSE):
          session.run(tf.global_variables_initializer())
          print(session.run([y], feed_dict={pred: False}))  # ==> [1]
          print(session.run([y], feed_dict={pred: True}))  # ==> [2]

          # Gets the updated variable
          list1 = tf.get_variable(dtype=tf.float32, name='list1')
          print(session.run(list1))

此打印

[array([2.], dtype=float32)]
[array([2., 1.], dtype=float32)]
[2. 1.]