当使用tf.cond()时,tensorflow会为不必要的占位符请求输入

时间:2018-02-13 02:56:07

标签: python tensorflow machine-learning

请考虑以下包含张量流tf.cond()的代码段。

    import tensorflow as tf
    import numpy as np

    bb = tf.placeholder(tf.bool)
    xx = tf.placeholder(tf.float32, name='xx')
    yy = tf.placeholder(tf.float32, name='yy')

    zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

    with tf.Session() as sess:
            dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
            print(sess.run(zz, feed_dict=dict1)) # works fine without errors

            dict2 = {bb:False, yy:np.array([1., 3, 4])}
            print(sess.run(zz, feed_dict=dict2)) # get an InvalidArgumentError asking to
                                                 # provide an input for xx

在这两种情况下,bb都是Falsezz的评估理论上不依赖xx,但仍然需要输入xx张量流。即使它可以作为虚拟数组提供,但它必须与yy的形状匹配,并且不像dict2那样干净。

有人可以建议如何评估zz(使用tf.cond()或任何其他方法)而不提供xx的值吗?

1 个答案:

答案 0 :(得分:8)

您可以将xx定义为tf.Variable,并为其指定一个默认值(只要xx没有提供其他值,就会使用该值)。有几点需要注意:

  1. 虽然xx不是占位符,但您仍然可以通过feed_dict将值添加到其中来对待它。
  2. 使用validate_shape=False,以便将任何形状投放到xx
  3. 使用trainable=False以便xx未优化(否则,优化程序可能会将其默认值更改为Nan,这可能会导致问题。)
  4. 请勿忘记使用例如xx初始化tf.global_variables_initializer()的值。
  5. 以下是代码:

    import tensorflow as tf
    import numpy as np
    
    bb = tf.placeholder(tf.bool)
    xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
    yy = tf.placeholder(tf.float32, name='yy')
    
    zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
    
    with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
       print(sess.run(zz, feed_dict=dict1))
       dict2 = {bb:False, yy:np.array([1., 3, 4])}
       print(sess.run(zz, feed_dict=dict2))