如果条件输出

时间:2019-07-24 16:43:14

标签: python tensorflow tensorflow2.0

我正在尝试评估变量a是否为空(即大小== 0)。但是,在使用@ tf.function装饰代码时,if语句错误地将其评估为True,而在删除装饰器时,其将其评估为False。在两种情况下,tf.size(a)似乎正确地计算为0。如何解决呢? 谢谢您的帮助!

import tensorflow as tf
a=tf.Variable([[]])
@tf.function
def test(a):
    print_op = tf.print(tf.size(a))
    print(tf.size(a))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
test(a)

1 个答案:

答案 0 :(得分:0)

这有点让人头疼,但是,一旦我们了解到tf.function正在将python ops和控制流映射到tf图,而裸函数只是急切地执行,我们就可以对其进行挑选。更有意义。

我已经调整了您的示例以说明正在发生的事情。考虑下面的test1test2

@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

除了@tf.function装饰器之外,它们彼此相同。

现在执行test2(tf.Variable([[]]))会给我们:

0
python print size: 0

这是我假设您期望的行为。而test1(tf.Variable([[]]))给出:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0

关于此输出,有几件事(除了fail之外,您可能会感到惊讶:

  • print()语句打印出(尚未评估)张量而不是零。
  • print()tf.print()的顺序已颠倒

这是因为通过添加@tf.function,我们不再具有python函数,而是具有使用autograph从函数代码映射的tf图。这意味着,在评估if条件时,我们尚未执行tf.math.not_equal(tf.size(a),0),而只有一个对象(Tensor对象的实例),在python中,是真的:

class MyClass:
  pass
my_obj = MyClass()
if (my_obj):
  print ("my_obj evaluates to true") ## outputs "my_obj evaluates to true"

这意味着我们在评估print('fail')之前先进入test1中的tf.math.not_equal(tf.size(a),0)语句。

那有什么解决办法?

好吧,如果我们在print()块中删除对仅python的if函数的调用,并将其替换为亲笔签名的tf.print()语句,则亲笔签名将无缝转换我们的{图友好型if ... else ...语句的{1}}逻辑可确保一切以正确的顺序发生:

def test3(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        tf.print('fail')
    with tf.control_dependencies([print_op]):
        return None
tf.cond
test3(tf.Variable([[]]))