安全检查tf.Tensor是否满足某些条件

时间:2020-06-16 16:12:53

标签: python tensorflow

如果我要检查tnsr = tf.constant(...)中的所有元素都大于3,则可以tf.reduce_all(tnsr > 3)获得标量布尔张量。如果我使用急切执行或@tf.function,则可以将其用作常规bool

@tf.function
def foo(tnsr):
    if tf.reduce_all(tnsr > 3):
        ...

但这是doesn't work with autograph=False。那我该怎么办?

我尝试过的其他事情:

  • tf.cond甚至返回tf.Tensor的{​​{1}}
  • true_fn=lambda: True失败的原因与上述相同

0 个答案:

没有答案