确定A值是否在TensorFlow中的Set中

时间:2016-01-05 17:54:49

标签: python set tensorflow

tf.logical_ortf.logical_andtf.select函数非常有用。

但是,假设您有价值x,并且您想知道它是否在set(a, b, c, d, e)中。在python中你只需写:

if x in set([a, b, c, d, e]):
  # Do some action.

据我所知,在TensorFlow中执行此操作的唯一方法是嵌套'tf.logical_or'和'tf.equal'。我在下面提供了这个概念的一次迭代:

tf.logical_or(
    tf.logical_or(tf.equal(x, a), tf.equal(x, b)),
    tf.logical_or(tf.equal(x, c), tf.equal(x, d))
)

我觉得在TensorFlow中必须有一种更简单的方法。有吗?

3 个答案:

答案 0 :(得分:3)

看看这个相关的问题:Count number of "True" values in boolean Tensor

您应该能够构建一个由[a,b,c,d,e]组成的张量,然后使用tf.equal(.)

检查是否有任何行等于x

答案 1 :(得分:3)

为了提供更具体的答案,假设您要检查张量x的最后一个维度是否包含1D张量s中的任何值,您可以执行以下操作:

tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1))

例如,对于sx

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0], 
                  [4, 4, 4, 0, 0]], 
                 [[3, 5, 5, 6, 4], 
                  [4, 7, 3, 8, 9]]])

x的形状为[2, 2, 5],而s的形状为[2] tile_multiples = [1, 1, 1, 2],这意味着我们会将x的最后一个维度平铺2次(对于s中的每个元素一次)沿着新维度。因此,x_tile将如下所示:

[[[[1 1]
   [2 2]
   [3 3]
   [0 0]
   [0 0]]

  [[4 4]
   [4 4]
   [4 4]
   [0 0]
   [0 0]]]

 [[[3 3]
   [5 5]
   [5 5]
   [6 6]
   [4 4]]

  [[4 4]
   [7 7]
   [3 3]
   [8 8]
   [9 9]]]]

x_in_s会将每个平铺值与s中的一个值进行比较。如果任何平铺值在tf.reduce_any中,则最后一个dim的s将返回true,给出最终结果:

[[[False False  True False False]
  [ True  True  True False False]]

 [[ True False False False  True]
  [ True False  True False False]]]

答案 2 :(得分:0)

这是两个解决方案,我们要检查query中是否有whitelist

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

#use broadcasting for element-wise tensor operation
broadcast_equal = tf.equal(whitelist, query)

#method 1: using tensor ops
broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
broadcast_sum = tf.reduce_sum(broadcast_equal_int)

#method 2: using some tf.core API
nz_cnt = tf.count_nonzero(broadcast_equal)

sess.run([broadcast_equal, broadcast_sum, nz_cnt])
#=> [array([False, False,  True, False]), 1, 1]

因此,如果输出为> 0,则该项目位于集合中。