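I have a one-dimensional `tf.uint8` tensor `x` and want to assert that all values in this tensor are in a set `s` that I define. `s` is fixed at graph-definition time, so it is not a dynamically computed tensor.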
In plain Python, I want to do something like the following:

```python
x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
assert all(el in s for el in x), "This should fail, as 5 is not in s"
```
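I know I can use `tf.Assert` for the assertion part, but I'm struggling to define the condition part (`el in s`). What is the simplest/most canonical way to do this?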
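The 2.5-year-old answer to "Determining if A Value is in a Set in TensorFlow" is not enough for me: first, it is very complicated to write and understand; second, it uses a broadcasted `tf.equal`, which is computationally more expensive than a proper set-based check.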
Answer (score: 2)
A simple way could be this:
```python
import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Compare every value in x against every value in s: shape (len(x), len(s))
xs_eq = tf.equal(x_t[:, tf.newaxis], s_t)
# Check that every element of x equals at least one element of s
assert_op = tf.Assert(tf.reduce_all(tf.reduce_any(xs_eq, axis=1)), [x_t])
with tf.control_dependencies([assert_op]):
    # Ops created in this block run only after the assertion has passed;
    # tf.identity is a placeholder for whatever actually uses x_t
    x_checked = tf.identity(x_t)
```
This creates an intermediate tensor of size `(len(x), len(s))`. If that is a problem, you can also split the problem into independent tensors, for example:
```python
import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
# For each value in s, mark with 1 the elements of x that equal it
x_in_s = [tf.cast(tf.equal(x_t, si), tf.int32) for si in s]
# Sum the matches and check there is at least one match per element of x
assert_op = tf.Assert(tf.reduce_all(tf.add_n(x_in_s) > 0), [x_t])
```
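Another option, not part of the original answer and offered only as a sketch: `tf.sets.difference` computes an actual set difference, which is closer to the set-based check the question asks for. It operates on the last dimension of inputs of rank 2 or higher, so both vectors get a leading dimension of size 1, and per its documentation it accepts integer dtypes such as `tf.uint8`. This assumes TF 2.x eager execution; the helper name `values_not_in_set` is hypothetical.

```python
import tensorflow as tf

def values_not_in_set(x_t, s_t):
    """Hypothetical helper: distinct elements of 1-D x_t that are not in 1-D s_t."""
    # tf.sets.difference works on the last dimension of rank >= 2 inputs,
    # so add a leading dimension of size 1 to both vectors
    diff = tf.sets.difference(x_t[tf.newaxis], s_t[tf.newaxis])
    # The result is a SparseTensor; its values are the missing elements
    return diff.values

x_t = tf.constant([1, 2, 3, 1, 11, 3, 5], dtype=tf.uint8)
s_t = tf.constant(sorted({1, 2, 3, 11, 12, 13}), dtype=tf.uint8)

missing = values_not_in_set(x_t, s_t)
no_missing = tf.equal(tf.size(missing), 0)
# In graph code the condition plugs into tf.Assert as in the answer:
# assert_op = tf.Assert(no_missing, [missing])
```

Passing `missing` rather than `x_t` as the `data` argument of `tf.Assert` would make the error message list only the offending values.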
Edit:

Actually, since you said the values are `tf.uint8`, you can do even better by using boolean arrays:
```python
import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Boolean indicator vectors (one slot per possible uint8 value)
# marking which values occur in x and in s
x_bool = tf.scatter_nd(tf.cast(x_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(x_t, dtype=tf.bool), [256])
s_bool = tf.scatter_nd(tf.cast(s_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(s_t, dtype=tf.bool), [256])
# Check that all values in x are in s
assert_op = tf.Assert(tf.reduce_all(tf.equal(x_bool, x_bool & s_bool)), [x_t])
```
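This takes linear time, and the intermediate boolean vectors have a fixed size of 256 no matter how long `x` or `s` is.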
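Because `s` is fixed at graph-definition time, a variation on the same idea (a sketch, not the original answer's code) is to build the 256-entry membership table for `s` once in NumPy and look up each element of `x` with `tf.gather`. This skips both `tf.scatter_nd` calls and the comparison of two full vectors. It assumes TF 2.x eager execution; `s_table` and `all_in_s` are hypothetical names.

```python
import numpy as np
import tensorflow as tf

s = {1, 2, 3, 11, 12, 13}

# s is known at graph-definition time, so its membership table
# (one entry per possible uint8 value) can be built once in NumPy
table = np.zeros(256, dtype=bool)
table[sorted(s)] = True
s_table = tf.constant(table)

def all_in_s(x_t):
    """Hypothetical helper: scalar tf.bool, True iff every element of uint8 x_t is in s."""
    # One table lookup per element of x: linear in len(x), independent of len(s)
    in_s = tf.gather(s_table, tf.cast(x_t, tf.int32))
    return tf.reduce_all(in_s)

x_t = tf.constant([1, 2, 3, 1, 11, 3, 5], dtype=tf.uint8)
condition = all_in_s(x_t)
# In graph code the condition plugs into tf.Assert as in the answer:
# assert_op = tf.Assert(condition, [x_t])
```

Since every `uint8` value is a valid index into a 256-entry table, the lookup can never go out of range.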
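Edit 2: Although in theory the last method is the best one for this case, after a couple of quick benchmarks I can only see a significant performance difference once I get to hundreds of thousands of elements, and in any case all three approaches remain quite fast with `tf.uint8`.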