我正在尝试评估变量a是否为空(即大小== 0)。但是,在使用@ tf.function装饰代码时,if语句错误地将其评估为True,而在删除装饰器时,其将其评估为False。在两种情况下,tf.size(a)似乎正确地计算为0。如何解决呢? 谢谢您的帮助!
import tensorflow as tf
a=tf.Variable([[]])
@tf.function
def test(a):
print_op = tf.print(tf.size(a))
print(tf.size(a))
if tf.math.not_equal(tf.size(a),0):
print('fail')
with tf.control_dependencies([print_op]):
return None
test(a)
答案 0 :(得分:0)
这有点让人头疼,但是,一旦我们了解到tf.function
正在将python ops和控制流映射到tf图,而裸函数只是急切地执行,我们就可以对其进行挑选。更有意义。
我已经调整了您的示例以说明正在发生的事情。考虑下面的test1
和test2
:
@tf.function
def test1(a):
print_op = tf.print(tf.size(a))
print("python print size: {}".format(tf.size(a)))
if tf.math.not_equal(tf.size(a),0):
print('fail')
with tf.control_dependencies([print_op]):
return None
def test2(a):
print_op = tf.print(tf.size(a))
print("python print size: {}".format(tf.size(a)))
if tf.math.not_equal(tf.size(a),0):
print('fail')
with tf.control_dependencies([print_op]):
return None
除了@tf.function
装饰器之外,它们彼此相同。
现在执行test2(tf.Variable([[]]))
会给我们:
0
python print size: 0
这是我假设您期望的行为。而test1(tf.Variable([[]]))
给出:
python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0
关于此输出,有几件事(除了fail
之外,您可能会感到惊讶:
print()
语句打印出(尚未评估)张量而不是零。print()
和tf.print()
的顺序已颠倒这是因为通过添加@tf.function
,我们不再具有python函数,而是具有使用autograph从函数代码映射的tf图。这意味着,在评估if
条件时,我们尚未执行tf.math.not_equal(tf.size(a),0)
,而只有一个对象(Tensor
对象的实例),在python中,是真的:
class MyClass:
pass
my_obj = MyClass()
if (my_obj):
print ("my_obj evaluates to true") ## outputs "my_obj evaluates to true"
这意味着我们在评估print('fail')
之前先进入test1
中的tf.math.not_equal(tf.size(a),0)
语句。
那有什么解决办法?
好吧,如果我们在print()
块中删除对仅python的if
函数的调用,并将其替换为亲笔签名的tf.print()
语句,则亲笔签名将无缝转换我们的{图友好型if ... else ...
语句的{1}}逻辑可确保一切以正确的顺序发生:
def test3(a): print_op = tf.print(tf.size(a)) print("python print size: {}".format(tf.size(a))) if tf.math.not_equal(tf.size(a),0): tf.print('fail') with tf.control_dependencies([print_op]): return None
tf.cond
test3(tf.Variable([[]]))