如果我要检查tnsr = tf.constant(...)
中的所有元素都大于3,则可以tf.reduce_all(tnsr > 3)
获得标量布尔张量。如果我使用急切执行或@tf.function
,则可以将其用作常规bool
:
@tf.function
def foo(tnsr):
if tf.reduce_all(tnsr > 3):
...
但这是doesn't work with autograph=False
。那我该怎么办?
我尝试过的其他事情:
tf.cond
甚至返回tf.Tensor
的{{1}} true_fn=lambda: True
失败的原因与上述相同